pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +83 -19
- pixeltable/_query.py +1444 -0
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +7 -4
- pixeltable/catalog/catalog.py +2394 -119
- pixeltable/catalog/column.py +225 -104
- pixeltable/catalog/dir.py +38 -9
- pixeltable/catalog/globals.py +53 -34
- pixeltable/catalog/insertable_table.py +265 -115
- pixeltable/catalog/path.py +80 -17
- pixeltable/catalog/schema_object.py +28 -43
- pixeltable/catalog/table.py +1270 -677
- pixeltable/catalog/table_metadata.py +103 -0
- pixeltable/catalog/table_version.py +1270 -751
- pixeltable/catalog/table_version_handle.py +109 -0
- pixeltable/catalog/table_version_path.py +137 -42
- pixeltable/catalog/tbl_ops.py +53 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +251 -134
- pixeltable/config.py +215 -0
- pixeltable/env.py +736 -285
- pixeltable/exceptions.py +26 -2
- pixeltable/exec/__init__.py +7 -2
- pixeltable/exec/aggregation_node.py +39 -21
- pixeltable/exec/cache_prefetch_node.py +87 -109
- pixeltable/exec/cell_materialization_node.py +268 -0
- pixeltable/exec/cell_reconstruction_node.py +168 -0
- pixeltable/exec/component_iteration_node.py +25 -28
- pixeltable/exec/data_row_batch.py +11 -46
- pixeltable/exec/exec_context.py +26 -11
- pixeltable/exec/exec_node.py +35 -27
- pixeltable/exec/expr_eval/__init__.py +3 -0
- pixeltable/exec/expr_eval/evaluators.py +365 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
- pixeltable/exec/expr_eval/globals.py +200 -0
- pixeltable/exec/expr_eval/row_buffer.py +74 -0
- pixeltable/exec/expr_eval/schedulers.py +413 -0
- pixeltable/exec/globals.py +35 -0
- pixeltable/exec/in_memory_data_node.py +35 -27
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +44 -29
- pixeltable/exec/sql_node.py +414 -115
- pixeltable/exprs/__init__.py +8 -5
- pixeltable/exprs/arithmetic_expr.py +79 -45
- pixeltable/exprs/array_slice.py +5 -5
- pixeltable/exprs/column_property_ref.py +40 -26
- pixeltable/exprs/column_ref.py +254 -61
- pixeltable/exprs/comparison.py +14 -9
- pixeltable/exprs/compound_predicate.py +9 -10
- pixeltable/exprs/data_row.py +213 -72
- pixeltable/exprs/expr.py +270 -104
- pixeltable/exprs/expr_dict.py +6 -5
- pixeltable/exprs/expr_set.py +20 -11
- pixeltable/exprs/function_call.py +383 -284
- pixeltable/exprs/globals.py +18 -5
- pixeltable/exprs/in_predicate.py +7 -7
- pixeltable/exprs/inline_expr.py +37 -37
- pixeltable/exprs/is_null.py +8 -4
- pixeltable/exprs/json_mapper.py +120 -54
- pixeltable/exprs/json_path.py +90 -60
- pixeltable/exprs/literal.py +61 -16
- pixeltable/exprs/method_ref.py +7 -6
- pixeltable/exprs/object_ref.py +19 -8
- pixeltable/exprs/row_builder.py +238 -75
- pixeltable/exprs/rowid_ref.py +53 -15
- pixeltable/exprs/similarity_expr.py +65 -50
- pixeltable/exprs/sql_element_cache.py +5 -5
- pixeltable/exprs/string_op.py +107 -0
- pixeltable/exprs/type_cast.py +25 -13
- pixeltable/exprs/variable.py +2 -2
- pixeltable/func/__init__.py +9 -5
- pixeltable/func/aggregate_function.py +197 -92
- pixeltable/func/callable_function.py +119 -35
- pixeltable/func/expr_template_function.py +101 -48
- pixeltable/func/function.py +375 -62
- pixeltable/func/function_registry.py +20 -19
- pixeltable/func/globals.py +6 -5
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +151 -35
- pixeltable/func/signature.py +178 -49
- pixeltable/func/tools.py +164 -0
- pixeltable/func/udf.py +176 -53
- pixeltable/functions/__init__.py +44 -4
- pixeltable/functions/anthropic.py +226 -47
- pixeltable/functions/audio.py +148 -11
- pixeltable/functions/bedrock.py +137 -0
- pixeltable/functions/date.py +188 -0
- pixeltable/functions/deepseek.py +113 -0
- pixeltable/functions/document.py +81 -0
- pixeltable/functions/fal.py +76 -0
- pixeltable/functions/fireworks.py +72 -20
- pixeltable/functions/gemini.py +249 -0
- pixeltable/functions/globals.py +208 -53
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1088 -95
- pixeltable/functions/image.py +155 -84
- pixeltable/functions/json.py +8 -11
- pixeltable/functions/llama_cpp.py +31 -19
- pixeltable/functions/math.py +169 -0
- pixeltable/functions/mistralai.py +50 -75
- pixeltable/functions/net.py +70 -0
- pixeltable/functions/ollama.py +29 -36
- pixeltable/functions/openai.py +548 -160
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +15 -14
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +310 -85
- pixeltable/functions/timestamp.py +37 -19
- pixeltable/functions/together.py +77 -120
- pixeltable/functions/twelvelabs.py +188 -0
- pixeltable/functions/util.py +7 -2
- pixeltable/functions/uuid.py +30 -0
- pixeltable/functions/video.py +1528 -117
- pixeltable/functions/vision.py +26 -26
- pixeltable/functions/voyageai.py +289 -0
- pixeltable/functions/whisper.py +19 -10
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/functions/yolox.py +112 -0
- pixeltable/globals.py +716 -236
- pixeltable/index/__init__.py +3 -1
- pixeltable/index/base.py +17 -21
- pixeltable/index/btree.py +32 -22
- pixeltable/index/embedding_index.py +155 -92
- pixeltable/io/__init__.py +12 -7
- pixeltable/io/datarows.py +140 -0
- pixeltable/io/external_store.py +83 -125
- pixeltable/io/fiftyone.py +24 -33
- pixeltable/io/globals.py +47 -182
- pixeltable/io/hf_datasets.py +96 -127
- pixeltable/io/label_studio.py +171 -156
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +136 -115
- pixeltable/io/parquet.py +40 -153
- pixeltable/io/table_data_conduit.py +702 -0
- pixeltable/io/utils.py +100 -0
- pixeltable/iterators/__init__.py +8 -4
- pixeltable/iterators/audio.py +207 -0
- pixeltable/iterators/base.py +9 -3
- pixeltable/iterators/document.py +144 -87
- pixeltable/iterators/image.py +17 -38
- pixeltable/iterators/string.py +15 -12
- pixeltable/iterators/video.py +523 -127
- pixeltable/metadata/__init__.py +33 -8
- pixeltable/metadata/converters/convert_10.py +2 -3
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_15.py +15 -11
- pixeltable/metadata/converters/convert_16.py +4 -5
- pixeltable/metadata/converters/convert_17.py +4 -5
- pixeltable/metadata/converters/convert_18.py +4 -6
- pixeltable/metadata/converters/convert_19.py +6 -9
- pixeltable/metadata/converters/convert_20.py +3 -6
- pixeltable/metadata/converters/convert_21.py +6 -8
- pixeltable/metadata/converters/convert_22.py +3 -2
- pixeltable/metadata/converters/convert_23.py +33 -0
- pixeltable/metadata/converters/convert_24.py +55 -0
- pixeltable/metadata/converters/convert_25.py +19 -0
- pixeltable/metadata/converters/convert_26.py +23 -0
- pixeltable/metadata/converters/convert_27.py +29 -0
- pixeltable/metadata/converters/convert_28.py +13 -0
- pixeltable/metadata/converters/convert_29.py +110 -0
- pixeltable/metadata/converters/convert_30.py +63 -0
- pixeltable/metadata/converters/convert_31.py +11 -0
- pixeltable/metadata/converters/convert_32.py +15 -0
- pixeltable/metadata/converters/convert_33.py +17 -0
- pixeltable/metadata/converters/convert_34.py +21 -0
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/convert_41.py +12 -0
- pixeltable/metadata/converters/convert_42.py +9 -0
- pixeltable/metadata/converters/convert_43.py +44 -0
- pixeltable/metadata/converters/util.py +44 -18
- pixeltable/metadata/notes.py +21 -0
- pixeltable/metadata/schema.py +185 -42
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +616 -225
- pixeltable/share/__init__.py +3 -0
- pixeltable/share/packager.py +797 -0
- pixeltable/share/protocol/__init__.py +33 -0
- pixeltable/share/protocol/common.py +165 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +119 -0
- pixeltable/share/publish.py +349 -0
- pixeltable/store.py +398 -232
- pixeltable/type_system.py +730 -267
- pixeltable/utils/__init__.py +40 -0
- pixeltable/utils/arrow.py +201 -29
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +346 -0
- pixeltable/utils/coco.py +26 -27
- pixeltable/utils/code.py +4 -4
- pixeltable/utils/console_output.py +46 -0
- pixeltable/utils/coroutine.py +24 -0
- pixeltable/utils/dbms.py +92 -0
- pixeltable/utils/description_helper.py +11 -12
- pixeltable/utils/documents.py +60 -61
- pixeltable/utils/exception_handler.py +36 -0
- pixeltable/utils/filecache.py +38 -22
- pixeltable/utils/formatter.py +88 -51
- pixeltable/utils/gcs_store.py +295 -0
- pixeltable/utils/http.py +133 -0
- pixeltable/utils/http_server.py +14 -13
- pixeltable/utils/iceberg.py +13 -0
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +90 -0
- pixeltable/utils/local_store.py +322 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +573 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +20 -20
- pixeltable/utils/s3_store.py +527 -0
- pixeltable/utils/sql.py +32 -5
- pixeltable/utils/system.py +30 -0
- pixeltable/utils/transactional_directory.py +4 -3
- pixeltable-0.5.7.dist-info/METADATA +579 -0
- pixeltable-0.5.7.dist-info/RECORD +227 -0
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
- pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/catalog/named_function.py +0 -36
- pixeltable/catalog/path_dict.py +0 -141
- pixeltable/dataframe.py +0 -894
- pixeltable/exec/expr_eval_node.py +0 -232
- pixeltable/ext/__init__.py +0 -14
- pixeltable/ext/functions/__init__.py +0 -8
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/ext/functions/yolox.py +0 -157
- pixeltable/tool/create_test_db_dump.py +0 -311
- pixeltable/tool/create_test_video.py +0 -81
- pixeltable/tool/doc_plugins/griffe.py +0 -50
- pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
- pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
- pixeltable/tool/embed_udf.py +0 -9
- pixeltable/tool/mypy_plugin.py +0 -55
- pixeltable/utils/media_store.py +0 -76
- pixeltable/utils/s3.py +0 -16
- pixeltable-0.2.26.dist-info/METADATA +0 -400
- pixeltable-0.2.26.dist-info/RECORD +0 -156
- pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
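
The largest single change under `pixeltable/functions/` is the expansion of `pixeltable/functions/huggingface.py` (+1088 −95); the diff below shows that file. As a quick orientation, here is a minimal sketch of calling one of the UDFs added in 0.5.7 — the table and column names are hypothetical, and the model runs locally (requires `pip install torch transformers`):

import pixeltable as pxt
from pixeltable.functions.huggingface import summarization

# Hypothetical table and column; summarization() is one of the UDFs added in 0.5.7
# and runs facebook/bart-large-cnn locally.
docs = pxt.create_table('docs', {'document_text': pxt.String})
docs.insert([{'document_text': 'Pixeltable stores multimodal data and runs models over it as computed columns.'}])
docs.add_computed_column(summary=summarization(docs.document_text, model_id='facebook/bart-large-cnn'))
print(docs.select(docs.summary).collect())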
--- a/pixeltable/functions/huggingface.py
+++ b/pixeltable/functions/huggingface.py
@@ -1,5 +1,5 @@
 """
-Pixeltable
+Pixeltable UDFs
 that wrap various models from the Hugging Face `transformers` package.
 
 These UDFs will cause Pixeltable to invoke the relevant models locally. In order to use them, you must
@@ -7,16 +7,22 @@ first `pip install transformers` (or in some cases, `sentence-transformers`, as
 UDFs).
 """
 
-from typing import Any, Callable,
+from typing import Any, Callable, Literal, TypeVar
 
+import av
+import numpy as np
 import PIL.Image
 
 import pixeltable as pxt
-import pixeltable.env as env
 import pixeltable.exceptions as excs
+import pixeltable.type_system as ts
+from pixeltable import env
 from pixeltable.func import Batch
 from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
 from pixeltable.utils.code import local_public_names
+from pixeltable.utils.local_store import TempStore
+
+T = TypeVar('T')
 
 
 @pxt.udf(batch_size=32)
@@ -46,12 +52,11 @@ def sentence_transformer(
         Add a computed column that applies the model `all-mpnet-base-2` to an existing Pixeltable column `tbl.sentence`
         of the table `tbl`:
 
-        >>> tbl
+        >>> tbl.add_computed_column(result=sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2'))
     """
     env.Env.get().require_package('sentence_transformers')
     device = resolve_torch_device('auto')
-    import
-    from sentence_transformers import SentenceTransformer  # type: ignore
+    from sentence_transformers import SentenceTransformer
 
     # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
     model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)
@@ -62,21 +67,17 @@ def sentence_transformer(
 
 
 @sentence_transformer.conditional_return_type
-def _(model_id: str) ->
-
-    from sentence_transformers import SentenceTransformer
+def _(model_id: str) -> ts.ArrayType:
+    from sentence_transformers import SentenceTransformer
 
-
-
-    except ImportError:
-        return pxt.ArrayType((None,), dtype=pxt.FloatType(), nullable=False)
+    model = _lookup_model(model_id, SentenceTransformer)
+    return ts.ArrayType((model.get_sentence_embedding_dimension(),), dtype=ts.FloatType(), nullable=False)
 
 
 @pxt.udf
 def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
     env.Env.get().require_package('sentence_transformers')
     device = resolve_torch_device('auto')
-    import torch
     from sentence_transformers import SentenceTransformer
 
     # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
@@ -111,13 +112,12 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
         Add a computed column that applies the model `ms-marco-MiniLM-L-4-v2` to the sentences in
         columns `tbl.sentence1` and `tbl.sentence2`:
 
-        >>> tbl
-
-
+        >>> tbl.add_computed_column(result=sentence_transformer(
+        ...     tbl.sentence1, tbl.sentence2, model_id='ms-marco-MiniLM-L-4-v2'
+        ... ))
     """
     env.Env.get().require_package('sentence_transformers')
     device = resolve_torch_device('auto')
-    import torch
     from sentence_transformers import CrossEncoder
 
     # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
@@ -132,7 +132,6 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
 def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> list:
     env.Env.get().require_package('sentence_transformers')
     device = resolve_torch_device('auto')
-    import torch
     from sentence_transformers import CrossEncoder
 
     # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
@@ -144,9 +143,9 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
 
 
 @pxt.udf(batch_size=32)
-def
+def clip(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
-    Computes a CLIP embedding for the specified text. `model_id` should be a reference to a pretrained
+    Computes a CLIP embedding for the specified text or image. `model_id` should be a reference to a pretrained
     [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
 
     __Requirements:__
@@ -164,12 +163,16 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), px
         Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
         Pixeltable column `tbl.text` of the table `tbl`:
 
-        >>> tbl
+        >>> tbl.add_computed_column(
+        ...     result=clip(tbl.text, model_id='openai/clip-vit-base-patch32')
+        ... )
+
+        The same would work with an image column `tbl.image` in place of `tbl.text`.
     """
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
     import torch
-    from transformers import CLIPModel, CLIPProcessor
+    from transformers import CLIPModel, CLIPProcessor
 
     model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
     processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
@@ -181,29 +184,8 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), px
     return [embeddings[i] for i in range(embeddings.shape[0])]
 
 
-@
-def
-    """
-    Computes a CLIP embedding for the specified image. `model_id` should be a reference to a pretrained
-    [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
-
-    __Requirements:__
-
-    - `pip install torch transformers`
-
-    Args:
-        image: The image to embed.
-        model_id: The pretrained model to use for the embedding.
-
-    Returns:
-        An array containing the output of the embedding model.
-
-    Examples:
-        Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
-        Pixeltable column `image` of the table `tbl`:
-
-        >>> tbl['result'] = clip_image(tbl.image, model_id='openai/clip-vit-base-patch32')
-    """
+@clip.overload
+def _(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
     import torch
@@ -219,25 +201,17 @@ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Arr
     return [embeddings[i] for i in range(embeddings.shape[0])]
 
 
-@
-
-
-    try:
-        from transformers import CLIPModel
+@clip.conditional_return_type
+def _(model_id: str) -> ts.ArrayType:
+    from transformers import CLIPModel
 
-
-
-    except ImportError:
-        return pxt.ArrayType((None,), dtype=pxt.FloatType(), nullable=False)
+    model = _lookup_model(model_id, CLIPModel.from_pretrained)
+    return ts.ArrayType((model.config.projection_dim,), dtype=ts.FloatType(), nullable=False)
 
 
 @pxt.udf(batch_size=4)
 def detr_for_object_detection(
-    image: Batch[PIL.Image.Image],
-    *,
-    model_id: str,
-    threshold: float = 0.5,
-    revision: str = 'no_timm',
+    image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5, revision: str = 'no_timm'
 ) -> Batch[dict]:
     """
     Computes DETR object detections for the specified image. `model_id` should be a reference to a pretrained
@@ -268,11 +242,11 @@ def detr_for_object_detection(
         Add a computed column that applies the model `facebook/detr-resnet-50` to an existing
         Pixeltable column `image` of the table `tbl`:
 
-        >>> tbl
+        >>> tbl.add_computed_column(detections=detr_for_object_detection(
         ...     tbl.image,
         ...     model_id='facebook/detr-resnet-50',
         ...     threshold=0.8
-        ... )
+        ... ))
     """
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
@@ -305,10 +279,7 @@ def detr_for_object_detection(
 
 @pxt.udf(batch_size=4)
 def vit_for_image_classification(
-    image: Batch[PIL.Image.Image],
-    *,
-    model_id: str,
-    top_k: int = 5
+    image: Batch[PIL.Image.Image], *, model_id: str, top_k: int = 5
 ) -> Batch[dict[str, Any]]:
     """
     Computes image classifications for the specified image using a Vision Transformer (ViT) model.
@@ -344,11 +315,11 @@ def vit_for_image_classification(
         Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
         Pixeltable column `image` of the table `tbl`, returning the 10 most likely classes for each image:
 
-        >>> tbl
+        >>> tbl.add_computed_column(image_class=vit_for_image_classification(
         ...     tbl.image,
         ...     model_id='google/vit-base-patch16-224',
        ...     top_k=10
-        ... )
+        ... ))
     """
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
@@ -380,12 +351,7 @@ def vit_for_image_classification(
 
 
 @pxt.udf
-def speech2text_for_conditional_generation(
-    audio: pxt.Audio,
-    *,
-    model_id: str,
-    language: Optional[str] = None,
-) -> str:
+def speech2text_for_conditional_generation(audio: pxt.Audio, *, model_id: str, language: str | None = None) -> str:
     """
     Transcribes or translates speech to text using a Speech2Text model. `model_id` should be a reference to a
     pretrained [Speech2Text](https://huggingface.co/docs/transformers/en/model_doc/speech_to_text) model.
@@ -408,19 +374,19 @@ def speech2text_for_conditional_generation(
         Add a computed column that applies the model `facebook/s2t-small-librispeech-asr` to an existing
         Pixeltable column `audio` of the table `tbl`:
 
-        >>> tbl
+        >>> tbl.add_computed_column(transcription=speech2text_for_conditional_generation(
         ...     tbl.audio,
        ...     model_id='facebook/s2t-small-librispeech-asr'
-        ... )
+        ... ))
 
         Add a computed column that applies the model `facebook/s2t-medium-mustc-multilingual-st` to an existing
         Pixeltable column `audio` of the table `tbl`, translating the audio to French:
 
-        >>> tbl
+        >>> tbl.add_computed_column(translation=speech2text_for_conditional_generation(
         ...     tbl.audio,
         ...     model_id='facebook/s2t-medium-mustc-multilingual-st',
         ...     language='fr'
-        ... )
+        ... ))
     """
     env.Env.get().require_package('transformers')
     env.Env.get().require_package('torchaudio')
@@ -428,18 +394,21 @@ def speech2text_for_conditional_generation(
     device = resolve_torch_device('auto', allow_mps=False)  # Doesn't seem to work on 'mps'; use 'cpu' instead
     import torch
     import torchaudio  # type: ignore[import-untyped]
-    from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+    from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor, Speech2TextTokenizer
 
     model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
     processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
+    tokenizer = processor.tokenizer
     assert isinstance(processor, Speech2TextProcessor)
+    assert isinstance(tokenizer, Speech2TextTokenizer)
 
-    if language is not None and language not in
+    if language is not None and language not in tokenizer.lang_code_to_id:
         raise excs.Error(
             f"Language code '{language}' is not supported by the model '{model_id}'. "
-            f
+            f'Supported languages are: {list(tokenizer.lang_code_to_id.keys())}'
+        )
 
-    forced_bos_token_id:
+    forced_bos_token_id: int | None = None if language is None else tokenizer.lang_code_to_id[language]
 
     # Get the model's sampling rate. Default to 16 kHz (the standard) if not in config
     model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
@@ -457,11 +426,7 @@ def speech2text_for_conditional_generation(
     assert waveform.dim() == 1
 
     with torch.no_grad():
-        inputs = processor(
-            waveform,
-            sampling_rate=model_sampling_rate,
-            return_tensors='pt'
-        )
+        inputs = processor(waveform, sampling_rate=model_sampling_rate, return_tensors='pt')
         generated_ids = model.generate(**inputs.to(device), forced_bos_token_id=forced_bos_token_id).to('cpu')
 
     transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
@@ -484,7 +449,7 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
         Add a computed column that converts the output `tbl.detections` to COCO format, where `tbl.image`
         is the image for which detections were computed:
 
-        >>> tbl
+        >>> tbl.add_computed_column(detections_coco=detr_to_coco(tbl.image, tbl.detections))
     """
     bboxes, labels = detr_info['boxes'], detr_info['labels']
     annotations = [
@@ -494,14 +459,1041 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
|
|
|
494
459
|
return {'image': {'width': image.width, 'height': image.height}, 'annotations': annotations}
|
|
495
460
|
|
|
496
461
|
|
|
497
|
-
|
|
462
|
+
@pxt.udf
|
|
463
|
+
def text_generation(text: str, *, model_id: str, model_kwargs: dict[str, Any] | None = None) -> str:
|
|
464
|
+
"""
|
|
465
|
+
Generates text using a pretrained language model. `model_id` should be a reference to a pretrained
|
|
466
|
+
[text generation model](https://huggingface.co/models?pipeline_tag=text-generation).
|
|
498
467
|
|
|
468
|
+
__Requirements:__
|
|
499
469
|
|
|
500
|
-
|
|
470
|
+
- `pip install torch transformers`
|
|
471
|
+
|
|
472
|
+
Args:
|
|
473
|
+
text: The input text to continue/complete.
|
|
474
|
+
model_id: The pretrained model to use for text generation.
|
|
475
|
+
model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`,
|
|
476
|
+
`temperature`, etc. See the
|
|
477
|
+
[Hugging Face text_generation documentation](https://huggingface.co/docs/inference-providers/en/tasks/text-generation)
|
|
478
|
+
for details.
|
|
479
|
+
|
|
480
|
+
Returns:
|
|
481
|
+
The generated text completion.
|
|
482
|
+
|
|
483
|
+
Examples:
|
|
484
|
+
Add a computed column that generates text completions using the `Qwen/Qwen3-0.6B` model:
|
|
485
|
+
|
|
486
|
+
>>> tbl.add_computed_column(completion=text_generation(
|
|
487
|
+
... tbl.prompt,
|
|
488
|
+
... model_id='Qwen/Qwen3-0.6B',
|
|
489
|
+
... model_kwargs={'temperature': 0.5, 'max_length': 150}
|
|
490
|
+
... ))
|
|
491
|
+
"""
|
|
492
|
+
env.Env.get().require_package('transformers')
|
|
493
|
+
device = resolve_torch_device('auto')
|
|
494
|
+
import torch
|
|
495
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
496
|
+
|
|
497
|
+
if model_kwargs is None:
|
|
498
|
+
model_kwargs = {}
|
|
499
|
+
|
|
500
|
+
model = _lookup_model(model_id, AutoModelForCausalLM.from_pretrained, device=device)
|
|
501
|
+
tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
|
|
502
|
+
|
|
503
|
+
if tokenizer.pad_token is None:
|
|
504
|
+
tokenizer.pad_token = tokenizer.eos_token
|
|
505
|
+
|
|
506
|
+
with torch.no_grad():
|
|
507
|
+
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
|
|
508
|
+
outputs = model.generate(**inputs.to(device), pad_token_id=tokenizer.eos_token_id, **model_kwargs)
|
|
509
|
+
|
|
510
|
+
input_length = len(inputs['input_ids'][0])
|
|
511
|
+
generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
|
|
512
|
+
return generated_text
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
@pxt.udf(batch_size=16)
|
|
516
|
+
def text_classification(text: Batch[str], *, model_id: str, top_k: int = 5) -> Batch[list[dict[str, Any]]]:
|
|
517
|
+
"""
|
|
518
|
+
Classifies text using a pretrained classification model. `model_id` should be a reference to a pretrained
|
|
519
|
+
[text classification model](https://huggingface.co/models?pipeline_tag=text-classification)
|
|
520
|
+
such as BERT, RoBERTa, or DistilBERT.
|
|
521
|
+
|
|
522
|
+
__Requirements:__
|
|
523
|
+
|
|
524
|
+
- `pip install torch transformers`
|
|
525
|
+
|
|
526
|
+
Args:
|
|
527
|
+
text: The text to classify.
|
|
528
|
+
model_id: The pretrained model to use for classification.
|
|
529
|
+
top_k: The number of top predictions to return.
|
|
530
|
+
|
|
531
|
+
Returns:
|
|
532
|
+
A dictionary containing classification results with scores, labels, and label text.
|
|
533
|
+
|
|
534
|
+
Examples:
|
|
535
|
+
Add a computed column for sentiment analysis:
|
|
536
|
+
|
|
537
|
+
>>> tbl.add_computed_column(sentiment=text_classification(
|
|
538
|
+
... tbl.review_text,
|
|
539
|
+
... model_id='cardiffnlp/twitter-roberta-base-sentiment-latest'
|
|
540
|
+
... ))
|
|
541
|
+
"""
|
|
542
|
+
env.Env.get().require_package('transformers')
|
|
543
|
+
device = resolve_torch_device('auto')
|
|
544
|
+
import torch
|
|
545
|
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
546
|
+
|
|
547
|
+
model = _lookup_model(model_id, AutoModelForSequenceClassification.from_pretrained, device=device)
|
|
548
|
+
tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
|
|
549
|
+
|
|
550
|
+
with torch.no_grad():
|
|
551
|
+
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
|
552
|
+
outputs = model(**inputs.to(device))
|
|
553
|
+
logits = outputs.logits
|
|
554
|
+
|
|
555
|
+
probs = torch.softmax(logits, dim=-1)
|
|
556
|
+
top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
|
|
557
|
+
|
|
558
|
+
results = []
|
|
559
|
+
for i in range(len(text)):
|
|
560
|
+
# Return as list of individual classification items for HuggingFace compatibility
|
|
561
|
+
classification_items = []
|
|
562
|
+
for k in range(top_k_probs.shape[1]):
|
|
563
|
+
classification_items.append(
|
|
564
|
+
{
|
|
565
|
+
'label': top_k_indices[i, k].item(),
|
|
566
|
+
'label_text': model.config.id2label[top_k_indices[i, k].item()],
|
|
567
|
+
'score': top_k_probs[i, k].item(),
|
|
568
|
+
}
|
|
569
|
+
)
|
|
570
|
+
results.append(classification_items)
|
|
571
|
+
|
|
572
|
+
return results
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
@pxt.udf(batch_size=4)
|
|
576
|
+
def image_captioning(
|
|
577
|
+
image: Batch[PIL.Image.Image], *, model_id: str, model_kwargs: dict[str, Any] | None = None
|
|
578
|
+
) -> Batch[str]:
|
|
579
|
+
"""
|
|
580
|
+
Generates captions for images using a pretrained image captioning model. `model_id` should be a reference to a
|
|
581
|
+
pretrained [image-to-text model](https://huggingface.co/models?pipeline_tag=image-to-text) such as BLIP,
|
|
582
|
+
Git, or LLaVA.
|
|
583
|
+
|
|
584
|
+
__Requirements:__
|
|
585
|
+
|
|
586
|
+
- `pip install torch transformers`
|
|
587
|
+
|
|
588
|
+
Args:
|
|
589
|
+
image: The image to caption.
|
|
590
|
+
model_id: The pretrained model to use for captioning.
|
|
591
|
+
model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`.
|
|
592
|
+
|
|
593
|
+
Returns:
|
|
594
|
+
The generated caption text.
|
|
595
|
+
|
|
596
|
+
Examples:
|
|
597
|
+
Add a computed column `caption` to an existing table `tbl` that generates captions using the
|
|
598
|
+
`Salesforce/blip-image-captioning-base` model:
|
|
599
|
+
|
|
600
|
+
>>> tbl.add_computed_column(caption=image_captioning(
|
|
601
|
+
... tbl.image,
|
|
602
|
+
... model_id='Salesforce/blip-image-captioning-base',
|
|
603
|
+
... model_kwargs={'max_length': 30}
|
|
604
|
+
... ))
|
|
605
|
+
"""
|
|
606
|
+
env.Env.get().require_package('transformers')
|
|
607
|
+
device = resolve_torch_device('auto')
|
|
608
|
+
import torch
|
|
609
|
+
from transformers import AutoModelForVision2Seq, AutoProcessor
|
|
610
|
+
|
|
611
|
+
if model_kwargs is None:
|
|
612
|
+
model_kwargs = {}
|
|
613
|
+
|
|
614
|
+
model = _lookup_model(model_id, AutoModelForVision2Seq.from_pretrained, device=device)
|
|
615
|
+
processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
|
|
616
|
+
normalized_images = [normalize_image_mode(img) for img in image]
|
|
617
|
+
|
|
618
|
+
with torch.no_grad():
|
|
619
|
+
inputs = processor(images=normalized_images, return_tensors='pt')
|
|
620
|
+
outputs = model.generate(**inputs.to(device), **model_kwargs)
|
|
621
|
+
|
|
622
|
+
captions = processor.batch_decode(outputs, skip_special_tokens=True)
|
|
623
|
+
return captions
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
@pxt.udf(batch_size=8)
|
|
627
|
+
def summarization(text: Batch[str], *, model_id: str, model_kwargs: dict[str, Any] | None = None) -> Batch[str]:
|
|
628
|
+
"""
|
|
629
|
+
Summarizes text using a pretrained summarization model. `model_id` should be a reference to a pretrained
|
|
630
|
+
[summarization model](https://huggingface.co/models?pipeline_tag=summarization) such as BART, T5, or Pegasus.
|
|
631
|
+
|
|
632
|
+
__Requirements:__
|
|
633
|
+
|
|
634
|
+
- `pip install torch transformers`
|
|
635
|
+
|
|
636
|
+
Args:
|
|
637
|
+
text: The text to summarize.
|
|
638
|
+
model_id: The pretrained model to use for summarization.
|
|
639
|
+
model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`.
|
|
640
|
+
|
|
641
|
+
Returns:
|
|
642
|
+
The generated summary text.
|
|
643
|
+
|
|
644
|
+
Examples:
|
|
645
|
+
Add a computed column that summarizes documents:
|
|
646
|
+
|
|
647
|
+
>>> tbl.add_computed_column(summary=text_summarization(
|
|
648
|
+
... tbl.document_text,
|
|
649
|
+
... model_id='facebook/bart-large-cnn',
|
|
650
|
+
... max_length=100
|
|
651
|
+
... ))
|
|
652
|
+
"""
|
|
653
|
+
env.Env.get().require_package('transformers')
|
|
654
|
+
device = resolve_torch_device('auto')
|
|
655
|
+
import torch
|
|
656
|
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
657
|
+
|
|
658
|
+
if model_kwargs is None:
|
|
659
|
+
model_kwargs = {}
|
|
660
|
+
|
|
661
|
+
model = _lookup_model(model_id, AutoModelForSeq2SeqLM.from_pretrained, device=device)
|
|
662
|
+
tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
|
|
663
|
+
|
|
664
|
+
with torch.no_grad():
|
|
665
|
+
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
|
|
666
|
+
outputs = model.generate(**inputs.to(device), **model_kwargs)
|
|
667
|
+
|
|
668
|
+
return tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
@pxt.udf
|
|
672
|
+
def token_classification(
|
|
673
|
+
text: str, *, model_id: str, aggregation_strategy: Literal['simple', 'first', 'average', 'max'] = 'simple'
|
|
674
|
+
) -> list[dict[str, Any]]:
|
|
675
|
+
"""
|
|
676
|
+
Extracts named entities from text using a pretrained named entity recognition (NER) model.
|
|
677
|
+
`model_id` should be a reference to a pretrained
|
|
678
|
+
[token classification model](https://huggingface.co/models?pipeline_tag=token-classification) for NER.
|
|
679
|
+
|
|
680
|
+
__Requirements:__
|
|
681
|
+
|
|
682
|
+
- `pip install torch transformers`
|
|
683
|
+
|
|
684
|
+
Args:
|
|
685
|
+
text: The text to analyze for named entities.
|
|
686
|
+
model_id: The pretrained model to use.
|
|
687
|
+
aggregation_strategy: Method used to aggregate tokens.
|
|
688
|
+
|
|
689
|
+
Returns:
|
|
690
|
+
A list of dictionaries containing entity information (text, label, confidence, start, end).
|
|
691
|
+
|
|
692
|
+
Examples:
|
|
693
|
+
Add a computed column that extracts named entities:
|
|
694
|
+
|
|
695
|
+
>>> tbl.add_computed_column(entities=token_classification(
|
|
696
|
+
... tbl.text,
|
|
697
|
+
... model_id='dbmdz/bert-large-cased-finetuned-conll03-english'
|
|
698
|
+
... ))
|
|
699
|
+
"""
|
|
700
|
+
env.Env.get().require_package('transformers')
|
|
701
|
+
device = resolve_torch_device('auto')
|
|
702
|
+
import torch
|
|
703
|
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
|
704
|
+
|
|
705
|
+
# Follow direct model loading pattern like other best practice functions
|
|
706
|
+
model = _lookup_model(model_id, AutoModelForTokenClassification.from_pretrained, device=device)
|
|
707
|
+
tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
|
|
708
|
+
|
|
709
|
+
# Validate aggregation strategy
|
|
710
|
+
valid_strategies = {'simple', 'first', 'average', 'max'}
|
|
711
|
+
if aggregation_strategy not in valid_strategies:
|
|
712
|
+
raise excs.Error(
|
|
713
|
+
f'Invalid aggregation_strategy {aggregation_strategy!r}. Must be one of: {", ".join(valid_strategies)}'
|
|
714
|
+
)
|
|
715
|
+
|
|
716
|
+
with torch.no_grad():
|
|
717
|
+
# Tokenize with special tokens and return offsets for entity extraction
|
|
718
|
+
inputs = tokenizer(
|
|
719
|
+
text,
|
|
720
|
+
return_tensors='pt',
|
|
721
|
+
truncation=True,
|
|
722
|
+
max_length=512,
|
|
723
|
+
return_offsets_mapping=True,
|
|
724
|
+
add_special_tokens=True,
|
|
725
|
+
)
|
|
726
|
+
|
|
727
|
+
# Get model predictions
|
|
728
|
+
outputs = model(**{k: v.to(device) for k, v in inputs.items() if k != 'offset_mapping'})
|
|
729
|
+
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
|
730
|
+
|
|
731
|
+
# Get the predicted labels and confidence scores
|
|
732
|
+
predicted_token_classes = predictions.argmax(dim=-1).squeeze().tolist()
|
|
733
|
+
confidence_scores = predictions.max(dim=-1).values.squeeze().tolist()
|
|
734
|
+
|
|
735
|
+
# Handle single token case
|
|
736
|
+
if not isinstance(predicted_token_classes, list):
|
|
737
|
+
predicted_token_classes = [predicted_token_classes]
|
|
738
|
+
confidence_scores = [confidence_scores]
|
|
739
|
+
|
|
740
|
+
# Extract entities from predictions
|
|
741
|
+
entities = []
|
|
742
|
+
offset_mapping = inputs['offset_mapping'][0].tolist()
|
|
743
|
+
|
|
744
|
+
current_entity = None
|
|
745
|
+
|
|
746
|
+
for token_class, confidence, (start_offset, end_offset) in zip(
|
|
747
|
+
predicted_token_classes, confidence_scores, offset_mapping
|
|
748
|
+
):
|
|
749
|
+
# Skip special tokens (offset is (0, 0))
|
|
750
|
+
if start_offset == 0 and end_offset == 0:
|
|
751
|
+
continue
|
|
752
|
+
|
|
753
|
+
label = model.config.id2label[token_class]
|
|
754
|
+
|
|
755
|
+
# Skip 'O' (outside) labels
|
|
756
|
+
if label == 'O':
|
|
757
|
+
if current_entity:
|
|
758
|
+
entities.append(current_entity)
|
|
759
|
+
current_entity = None
|
|
760
|
+
continue
|
|
761
|
+
|
|
762
|
+
# Parse BIO/BILOU tags
|
|
763
|
+
if label.startswith('B-') or (label.startswith('I-') and current_entity is None):
|
|
764
|
+
# Begin new entity
|
|
765
|
+
if current_entity:
|
|
766
|
+
entities.append(current_entity)
|
|
767
|
+
|
|
768
|
+
entity_type = label[2:] if label.startswith(('B-', 'I-')) else label
|
|
769
|
+
current_entity = {
|
|
770
|
+
'word': text[start_offset:end_offset],
|
|
771
|
+
'entity_group': entity_type,
|
|
772
|
+
'score': float(confidence),
|
|
773
|
+
'start': start_offset,
|
|
774
|
+
'end': end_offset,
|
|
775
|
+
}
|
|
776
|
+
|
|
777
|
+
elif label.startswith('I-') and current_entity:
|
|
778
|
+
# Continue current entity
|
|
779
|
+
entity_type = label[2:]
|
|
780
|
+
if current_entity['entity_group'] == entity_type:
|
|
781
|
+
# Extend the current entity
|
|
782
|
+
current_entity['word'] = text[current_entity['start'] : end_offset]
|
|
783
|
+
current_entity['end'] = end_offset
|
|
784
|
+
|
|
785
|
+
# Update confidence based on aggregation strategy
|
|
786
|
+
if aggregation_strategy == 'average':
|
|
787
|
+
# Simple average (could be improved with token count weighting)
|
|
788
|
+
current_entity['score'] = (current_entity['score'] + float(confidence)) / 2
|
|
789
|
+
elif aggregation_strategy == 'max':
|
|
790
|
+
current_entity['score'] = max(current_entity['score'], float(confidence))
|
|
791
|
+
elif aggregation_strategy == 'first':
|
|
792
|
+
pass # Keep first confidence
|
|
793
|
+
# 'simple' uses the same logic as 'first'
|
|
794
|
+
else:
|
|
795
|
+
# Different entity type, start new entity
|
|
796
|
+
entities.append(current_entity)
|
|
797
|
+
current_entity = {
|
|
798
|
+
'word': text[start_offset:end_offset],
|
|
799
|
+
'entity_group': entity_type,
|
|
800
|
+
'score': float(confidence),
|
|
801
|
+
'start': start_offset,
|
|
802
|
+
'end': end_offset,
|
|
803
|
+
}
|
|
804
|
+
|
|
805
|
+
# Don't forget the last entity
|
|
806
|
+
if current_entity:
|
|
807
|
+
entities.append(current_entity)
|
|
808
|
+
|
|
809
|
+
return entities
|
|
810
|
+
|
|
811
|
+
|
|
812
|
+
@pxt.udf
|
|
813
|
+
def question_answering(context: str, question: str, *, model_id: str) -> dict[str, Any]:
|
|
814
|
+
"""
|
|
815
|
+
Answers questions based on provided context using a pretrained QA model. `model_id` should be a reference to a
|
|
816
|
+
pretrained [question answering model](https://huggingface.co/models?pipeline_tag=question-answering) such as
|
|
817
|
+
BERT or RoBERTa.
|
|
818
|
+
|
|
819
|
+
__Requirements:__
|
|
820
|
+
|
|
821
|
+
- `pip install torch transformers`
|
|
822
|
+
|
|
823
|
+
Args:
|
|
824
|
+
context: The context text containing the answer.
|
|
825
|
+
question: The question to answer.
|
|
826
|
+
model_id: The pretrained QA model to use.
|
|
827
|
+
|
|
828
|
+
Returns:
|
|
829
|
+
A dictionary containing the answer, confidence score, and start/end positions.
|
|
830
|
+
|
|
831
|
+
Examples:
|
|
832
|
+
Add a computed column that answers questions based on document context:
|
|
833
|
+
|
|
834
|
+
>>> tbl.add_computed_column(answer=question_answering(
|
|
835
|
+
... tbl.document_text,
|
|
836
|
+
... tbl.question,
|
|
837
|
+
... model_id='deepset/roberta-base-squad2'
|
|
838
|
+
... ))
|
|
839
|
+
"""
|
|
840
|
+
env.Env.get().require_package('transformers')
|
|
841
|
+
device = resolve_torch_device('auto')
|
|
842
|
+
import torch
|
|
843
|
+
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
|
|
844
|
+
|
|
845
|
+
model = _lookup_model(model_id, AutoModelForQuestionAnswering.from_pretrained, device=device)
|
|
846
|
+
tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
|
|
847
|
+
|
|
848
|
+
with torch.no_grad():
|
|
849
|
+
# Tokenize the question and context
|
|
850
|
+
inputs = tokenizer.encode_plus(
|
|
851
|
+
question, context, add_special_tokens=True, return_tensors='pt', truncation=True, max_length=512
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
# Get model predictions
|
|
855
|
+
outputs = model(**inputs.to(device))
|
|
856
|
+
start_scores = outputs.start_logits
|
|
857
|
+
end_scores = outputs.end_logits
|
|
858
|
+
|
|
859
|
+
# Find the tokens with the highest start and end scores
|
|
860
|
+
start_idx = torch.argmax(start_scores)
|
|
861
|
+
end_idx = torch.argmax(end_scores)
|
|
862
|
+
|
|
863
|
+
# Ensure end_idx >= start_idx
|
|
864
|
+
end_idx = torch.max(end_idx, start_idx)
|
|
865
|
+
|
|
866
|
+
# Convert token positions to string
|
|
867
|
+
input_ids = inputs['input_ids'][0]
|
|
868
|
+
|
|
869
|
+
# Extract answer tokens
|
|
870
|
+
answer_tokens = input_ids[start_idx : end_idx + 1]
|
|
871
|
+
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
|
|
872
|
+
|
|
873
|
+
# Calculate confidence score
|
|
874
|
+
start_probs = torch.softmax(start_scores, dim=1)
|
|
875
|
+
end_probs = torch.softmax(end_scores, dim=1)
|
|
876
|
+
confidence = float(start_probs[0][start_idx] * end_probs[0][end_idx])
|
|
877
|
+
|
|
878
|
+
return {'answer': answer.strip(), 'score': confidence, 'start': int(start_idx), 'end': int(end_idx)}
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
@pxt.udf(batch_size=8)
|
|
882
|
+
def translation(
|
|
883
|
+
text: Batch[str], *, model_id: str, src_lang: str | None = None, target_lang: str | None = None
|
|
884
|
+
) -> Batch[str]:
|
|
885
|
+
"""
|
|
886
|
+
Translates text using a pretrained translation model. `model_id` should be a reference to a pretrained
|
|
887
|
+
[translation model](https://huggingface.co/models?pipeline_tag=translation) such as MarianMT or T5.
|
|
888
|
+
|
|
889
|
+
__Requirements:__
|
|
890
|
+
|
|
891
|
+
- `pip install torch transformers sentencepiece`
|
|
892
|
+
|
|
893
|
+
Args:
|
|
894
|
+
text: The text to translate.
|
|
895
|
+
model_id: The pretrained translation model to use.
|
|
896
|
+
src_lang: Source language code (optional, can be inferred from model).
|
|
897
|
+
target_lang: Target language code (optional, can be inferred from model).
|
|
898
|
+
|
|
899
|
+
Returns:
|
|
900
|
+
The translated text.
|
|
901
|
+
|
|
902
|
+
Examples:
|
|
903
|
+
Add a computed column that translates text:
|
|
904
|
+
|
|
905
|
+
>>> tbl.add_computed_column(french_text=translation(
|
|
906
|
+
... tbl.english_text,
|
|
907
|
+
... model_id='Helsinki-NLP/opus-mt-en-fr',
|
|
908
|
+
... src_lang='en',
|
|
909
|
+
... target_lang='fr'
|
|
910
|
+
... ))
|
|
911
|
+
"""
|
|
912
|
+
env.Env.get().require_package('transformers')
|
|
913
|
+
device = resolve_torch_device('auto')
|
|
914
|
+
import torch
|
|
915
|
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
916
|
+
|
|
917
|
+
model = _lookup_model(model_id, AutoModelForSeq2SeqLM.from_pretrained, device=device)
|
|
918
|
+
tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
|
|
919
|
+
lang_code_to_id: dict | None = getattr(tokenizer, 'lang_code_to_id', {})
|
|
920
|
+
|
|
921
|
+
# Language validation - following speech2text_for_conditional_generation pattern
|
|
922
|
+
if src_lang is not None and src_lang not in lang_code_to_id:
|
|
923
|
+
raise excs.Error(
|
|
924
|
+
f'Source language code {src_lang!r} is not supported by the model {model_id!r}. '
|
|
925
|
+
f'Supported languages are: {list(lang_code_to_id.keys())}'
|
|
926
|
+
)
|
|
927
|
+
|
|
928
|
+
if target_lang is not None and target_lang not in lang_code_to_id:
|
|
929
|
+
raise excs.Error(
|
|
930
|
+
f'Target language code {target_lang!r} is not supported by the model {model_id!r}. '
|
|
931
|
+
f'Supported languages are: {list(lang_code_to_id.keys())}'
|
|
932
|
+
)
|
|
933
|
+
|
|
934
|
+
with torch.no_grad():
|
|
935
|
+
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
|
936
|
+
|
|
937
|
+
# Set forced_bos_token_id for target language if supported
|
|
938
|
+
generate_kwargs = {'max_length': 512, 'num_beams': 4, 'early_stopping': True}
|
|
939
|
+
|
|
940
|
+
if target_lang is not None:
|
|
941
|
+
generate_kwargs['forced_bos_token_id'] = lang_code_to_id[target_lang]
|
|
942
|
+
|
|
943
|
+
outputs = model.generate(**inputs.to(device), **generate_kwargs)
|
|
944
|
+
|
|
945
|
+
# Decode all outputs at once
|
|
946
|
+
translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|
947
|
+
return translations
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
@pxt.udf
|
|
951
|
+
def text_to_image(
|
|
952
|
+
prompt: str,
|
|
953
|
+
*,
|
|
954
|
+
model_id: str,
|
|
955
|
+
height: int = 512,
|
|
956
|
+
width: int = 512,
|
|
957
|
+
seed: int | None = None,
|
|
958
|
+
model_kwargs: dict[str, Any] | None = None,
|
|
959
|
+
) -> PIL.Image.Image:
|
|
960
|
+
"""
|
|
961
|
+
Generates images from text prompts using a pretrained text-to-image model. `model_id` should be a reference to a
|
|
962
|
+
pretrained [text-to-image model](https://huggingface.co/models?pipeline_tag=text-to-image) such as
|
|
963
|
+
Stable Diffusion or FLUX.
|
|
964
|
+
|
|
965
|
+
__Requirements:__
|
|
966
|
+
|
|
967
|
+
- `pip install torch transformers diffusers accelerate`
|
|
968
|
+
|
|
969
|
+
Args:
|
|
970
|
+
prompt: The text prompt describing the desired image.
|
|
971
|
+
model_id: The pretrained text-to-image model to use.
|
|
972
|
+
height: Height of the generated image in pixels.
|
|
973
|
+
width: Width of the generated image in pixels.
|
|
974
|
+
seed: Optional random seed for reproducibility.
|
|
975
|
+
model_kwargs: Additional keyword arguments to pass to the model, such as `num_inference_steps`,
|
|
976
|
+
`guidance_scale`, or `negative_prompt`.
|
|
977
|
+
|
|
978
|
+
Returns:
|
|
979
|
+
The generated Image.
|
|
980
|
+
|
|
981
|
+
Examples:
|
|
982
|
+
Add a computed column that generates images from text prompts:
|
|
983
|
+
|
|
984
|
+
>>> tbl.add_computed_column(generated_image=text_to_image(
|
|
985
|
+
... tbl.prompt,
|
|
986
|
+
... model_id='stable-diffusion-v1.5/stable-diffusion-v1-5',
|
|
987
|
+
... height=512,
|
|
988
|
+
... width=512,
|
|
989
|
+
... model_kwargs={'num_inference_steps': 25},
|
|
990
|
+
... ))
|
|
991
|
+
"""
|
|
992
|
+
env.Env.get().require_package('transformers')
|
|
993
|
+
env.Env.get().require_package('diffusers')
|
|
994
|
+
env.Env.get().require_package('accelerate')
|
|
995
|
+
device = resolve_torch_device('auto', allow_mps=False)
|
|
996
|
+
import torch
|
|
997
|
+
from diffusers import AutoPipelineForText2Image
|
|
998
|
+
|
|
999
|
+
if model_kwargs is None:
|
|
1000
|
+
model_kwargs = {}
|
|
1001
|
+
|
|
1002
|
+
# Parameter validation - following best practices pattern
|
|
1003
|
+
if height <= 0 or width <= 0:
|
|
1004
|
+
raise excs.Error(f'Height ({height}) and width ({width}) must be positive integers')
|
|
1005
|
+
|
|
1006
|
+
if height % 8 != 0 or width % 8 != 0:
|
|
1007
|
+
raise excs.Error(f'Height ({height}) and width ({width}) must be divisible by 8 for most diffusion models')
|
|
1008
|
+
|
|
1009
|
+
pipeline = _lookup_model(
|
|
1010
|
+
model_id,
|
|
1011
|
+
lambda x: AutoPipelineForText2Image.from_pretrained(
|
|
1012
|
+
x,
|
|
1013
|
+
torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
|
|
1014
|
+
device_map='auto' if device == 'cuda' else None,
|
|
1015
|
+
safety_checker=None, # Disable safety checker for performance
|
|
1016
|
+
requires_safety_checker=False,
|
|
1017
|
+
),
|
|
1018
|
+
device=device,
|
|
1019
|
+
)
|
|
1020
|
+
|
|
1021
|
+
try:
|
|
1022
|
+
if device == 'cuda' and hasattr(pipeline, 'enable_model_cpu_offload'):
|
|
1023
|
+
pipeline.enable_model_cpu_offload()
|
|
1024
|
+
if hasattr(pipeline, 'enable_memory_efficient_attention'):
|
|
1025
|
+
pipeline.enable_memory_efficient_attention()
|
|
1026
|
+
except Exception:
|
|
1027
|
+
pass # Ignore optimization failures
|
|
1028
|
+
|
|
1029
|
+
generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
|
|
1030
|
+
|
|
1031
|
+
with torch.no_grad():
|
|
1032
|
+
result = pipeline(prompt, height=height, width=width, generator=generator, **model_kwargs)
|
|
1033
|
+
return result.images[0]
|
|
1034
|
+
|
|
1035
|
+
|
|
1036
|
+
@pxt.udf
|
|
1037
|
+
def text_to_speech(text: str, *, model_id: str, speaker_id: int | None = None, vocoder: str | None = None) -> pxt.Audio:
|
|
1038
|
+
"""
|
|
1039
|
+
Converts text to speech using a pretrained TTS model. `model_id` should be a reference to a
|
|
1040
|
+
pretrained [text-to-speech model](https://huggingface.co/models?pipeline_tag=text-to-speech).
|
|
1041
|
+
|
|
1042
|
+
__Requirements:__
|
|
1043
|
+
|
|
1044
|
+
- `pip install torch transformers datasets soundfile`
|
|
1045
|
+
|
|
1046
|
+
Args:
|
|
1047
|
+
text: The text to convert to speech.
|
|
1048
|
+
model_id: The pretrained TTS model to use.
|
|
1049
|
+
speaker_id: Speaker ID for multi-speaker models.
|
|
1050
|
+
vocoder: Optional vocoder model for higher quality audio.
|
|
1051
|
+
|
|
1052
|
+
Returns:
|
|
1053
|
+
The generated audio file.
|
|
1054
|
+
|
|
1055
|
+
Examples:
|
|
1056
|
+
Add a computed column that converts text to speech:
|
|
1057
|
+
|
|
1058
|
+
>>> tbl.add_computed_column(audio=text_to_speech(
|
|
1059
|
+
... tbl.text_content,
|
|
1060
|
+
... model_id='microsoft/speecht5_tts',
|
|
1061
|
+
... speaker_id=0
|
|
1062
|
+
... ))
|
|
1063
|
+
"""
|
|
1064
|
+
env.Env.get().require_package('transformers')
|
|
1065
|
+
env.Env.get().require_package('datasets')
|
|
1066
|
+
env.Env.get().require_package('soundfile')
|
|
1067
|
+
+    device = resolve_torch_device('auto')
+    import datasets  # type: ignore[import-untyped]
+    import soundfile as sf  # type: ignore[import-untyped]
+    import torch
+    from transformers import (
+        AutoModelForTextToWaveform,
+        AutoProcessor,
+        BarkModel,
+        SpeechT5ForTextToSpeech,
+        SpeechT5HifiGan,
+        SpeechT5Processor,
+    )
+
+    # Model loading with error handling - following best practices pattern
+    if 'speecht5' in model_id.lower():
+        model = _lookup_model(model_id, SpeechT5ForTextToSpeech.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, SpeechT5Processor.from_pretrained)
+        vocoder_model_id = vocoder or 'microsoft/speecht5_hifigan'
+        vocoder_model = _lookup_model(vocoder_model_id, SpeechT5HifiGan.from_pretrained, device=device)
+
+    elif 'bark' in model_id.lower():
+        model = _lookup_model(model_id, BarkModel.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+        vocoder_model = None
+
+    else:
+        model = _lookup_model(model_id, AutoModelForTextToWaveform.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+        vocoder_model = None
+
+    # Load speaker embeddings once for SpeechT5 (following speech2text pattern)
+    speaker_embeddings = None
+    if 'speecht5' in model_id.lower():
+        ds: datasets.Dataset
+        if len(_speecht5_embeddings_dataset) == 0:
+            ds = datasets.load_dataset(
+                'Matthijs/cmu-arctic-xvectors', split='validation', revision='refs/convert/parquet'
+            )
+            _speecht5_embeddings_dataset.append(ds)
+        else:
+            assert len(_speecht5_embeddings_dataset) == 1
+            ds = _speecht5_embeddings_dataset[0]
+        speaker_embeddings = torch.tensor(ds[speaker_id or 7306]['xvector']).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        # Generate speech based on model type
+        if 'speecht5' in model_id.lower():
+            inputs = processor(text=text, return_tensors='pt').to(device)
+            speech = model.generate_speech(inputs['input_ids'], speaker_embeddings, vocoder=vocoder_model)
+            audio_np = speech.cpu().numpy()
+            sample_rate = 16000
+
+        elif 'bark' in model_id.lower():
+            inputs = processor(text, return_tensors='pt').to(device)
+            audio_array = model.generate(**inputs)
+            audio_np = audio_array.cpu().numpy().squeeze()
+            sample_rate = getattr(model.generation_config, 'sample_rate', 24000)
+
+        else:
+            # Generic approach for other TTS models
+            inputs = processor(text, return_tensors='pt').to(device)
+            audio_output = model(**inputs)
+            audio_np = audio_output.waveform.cpu().numpy().squeeze()
+            sample_rate = getattr(model.config, 'sample_rate', 22050)
+
+    # Normalize audio - following consistent pattern
+    if audio_np.dtype != np.float32:
+        audio_np = audio_np.astype(np.float32)
+
+    if np.max(np.abs(audio_np)) > 0:
+        audio_np = audio_np / np.max(np.abs(audio_np)) * 0.9
+
+    # Create output file
+    output_filename = str(TempStore.create_path(extension='.wav'))
+    sf.write(output_filename, audio_np, sample_rate, format='WAV', subtype='PCM_16')
+    return output_filename
+
+
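The `text_to_speech` body above dispatches on the model family named in `model_id` (SpeechT5 with an optional HiFi-GAN vocoder, Bark, or a generic text-to-waveform model), normalizes the waveform, and returns the path of a 16-bit WAV file. A minimal usage sketch, assuming the UDF is exposed as `pixeltable.functions.huggingface.text_to_speech`; the table and column names are illustrative, not taken from this diff:

```python
import pixeltable as pxt
from pixeltable.functions.huggingface import text_to_speech  # assumed module path

# Hypothetical table; 'sentence' is an illustrative column name.
tbl = pxt.create_table('tts_demo', {'sentence': pxt.String})

# Store the generated WAV file for each row in a computed column.
tbl.add_computed_column(
    speech=text_to_speech(tbl.sentence, model_id='microsoft/speecht5_tts')
)
tbl.insert([{'sentence': 'Hello from Pixeltable.'}])
```
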
+@pxt.udf
+def image_to_image(
+    image: PIL.Image.Image,
+    prompt: str,
+    *,
+    model_id: str,
+    seed: int | None = None,
+    model_kwargs: dict[str, Any] | None = None,
+) -> PIL.Image.Image:
+    """
+    Transforms input images based on text prompts using a pretrained image-to-image model.
+    `model_id` should be a reference to a pretrained
+    [image-to-image model](https://huggingface.co/models?pipeline_tag=image-to-image).
+
+    __Requirements:__
+
+    - `pip install torch transformers diffusers accelerate`
+
+    Args:
+        image: The input image to transform.
+        prompt: The text prompt describing the desired transformation.
+        model_id: The pretrained image-to-image model to use.
+        seed: Random seed for reproducibility.
+        model_kwargs: Additional keyword arguments to pass to the model, such as `strength`,
+            `guidance_scale`, or `num_inference_steps`.
+
+    Returns:
+        The transformed image.
+
+    Examples:
+        Add a computed column that transforms images based on prompts:
+
+        >>> tbl.add_computed_column(transformed=image_to_image(
+        ...     tbl.source_image,
+        ...     tbl.transformation_prompt,
+        ...     model_id='runwayml/stable-diffusion-v1-5'
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('diffusers')
+    env.Env.get().require_package('accelerate')
+    device = resolve_torch_device('auto')
+    import torch
+    from diffusers import StableDiffusionImg2ImgPipeline
+
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    pipe = _lookup_model(
+        model_id,
+        lambda x: StableDiffusionImg2ImgPipeline.from_pretrained(
+            x,
+            torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
+            safety_checker=None,
+            requires_safety_checker=False,
+        ),
+        device=device,
+    )
+
+    try:
+        if device == 'cuda' and hasattr(pipe, 'enable_model_cpu_offload'):
+            pipe.enable_model_cpu_offload()
+        if hasattr(pipe, 'enable_memory_efficient_attention'):
+            pipe.enable_memory_efficient_attention()
+    except Exception:
+        pass  # Ignore optimization failures
+
+    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
+
+    processed_image = image.convert('RGB')
+
+    with torch.no_grad():
+        result = pipe(prompt=prompt, image=processed_image, generator=generator, **model_kwargs)
+        return result.images[0]
+
+
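Since `image_to_image` forwards `model_kwargs` verbatim to the `StableDiffusionImg2ImgPipeline` call, generation parameters such as `strength` or `guidance_scale` can be tuned per column. A hedged sketch, again with hypothetical table and column names:

```python
import pixeltable as pxt
from pixeltable.functions.huggingface import image_to_image  # assumed module path

# Hypothetical table holding a source image and a text prompt per row.
tbl = pxt.create_table('img2img_demo', {'source_image': pxt.Image, 'prompt': pxt.String})

tbl.add_computed_column(
    stylized=image_to_image(
        tbl.source_image,
        tbl.prompt,
        model_id='runwayml/stable-diffusion-v1-5',
        seed=42,  # fixed seed -> reproducible generator state
        # forwarded unchanged to the diffusers pipeline call
        model_kwargs={'strength': 0.6, 'guidance_scale': 7.5, 'num_inference_steps': 30},
    )
)
```
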
+@pxt.udf
+def automatic_speech_recognition(
+    audio: pxt.Audio,
+    *,
     model_id: str,
-
-
-
+    language: str | None = None,
+    chunk_length_s: int | None = None,
+    return_timestamps: bool = False,
+) -> str:
+    """
+    Transcribes speech to text using a pretrained ASR model. `model_id` should be a reference to a
+    pretrained [automatic-speech-recognition model](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition).
+
+    This is a **generic function** that works with many ASR model families. For production use with
+    specific models, consider specialized functions like `whisper.transcribe()` or
+    `speech2text_for_conditional_generation()`.
+
+    __Requirements:__
+
+    - `pip install torch transformers torchaudio`
+
+    __Recommended Models:__
+
+    - **OpenAI Whisper**: `openai/whisper-tiny.en`, `openai/whisper-small`, `openai/whisper-base`
+    - **Facebook Wav2Vec2**: `facebook/wav2vec2-base-960h`, `facebook/wav2vec2-large-960h-lv60-self`
+    - **Microsoft SpeechT5**: `microsoft/speecht5_asr`
+    - **Meta MMS (Multilingual)**: `facebook/mms-1b-all`
+
+    Args:
+        audio: The audio file(s) to transcribe.
+        model_id: The pretrained ASR model to use.
+        language: Language code for multilingual models (e.g., 'en', 'es', 'fr').
+        chunk_length_s: Maximum length of audio chunks in seconds for long audio processing.
+        return_timestamps: Whether to return word-level timestamps (model dependent).
+
+    Returns:
+        The transcribed text.
+
+    Examples:
+        Add a computed column that transcribes audio files:
+
+        >>> tbl.add_computed_column(transcription=automatic_speech_recognition(
+        ...     tbl.audio_file,
+        ...     model_id='openai/whisper-tiny.en'  # Recommended
+        ... ))
+
+        Transcribe with language specification:
+
+        >>> tbl.add_computed_column(transcription=automatic_speech_recognition(
+        ...     tbl.audio_file,
+        ...     model_id='facebook/mms-1b-all',
+        ...     language='en'
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('torchaudio')
+    device = resolve_torch_device('auto', allow_mps=False)  # Following speech2text pattern
+    import torch
+    import torchaudio
+
+    # Try to load model and processor using direct model loading - following speech2text pattern
+    # Handle different ASR model types
+    if 'whisper' in model_id.lower():
+        from transformers import WhisperForConditionalGeneration, WhisperProcessor
+
+        model = _lookup_model(model_id, WhisperForConditionalGeneration.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, WhisperProcessor.from_pretrained)
+
+        # Language validation for Whisper - following speech2text pattern
+        if language is not None and hasattr(processor.tokenizer, 'get_decoder_prompt_ids'):
+            try:
+                # Test if language is supported
+                _ = processor.tokenizer.get_decoder_prompt_ids(language=language)
+            except Exception:
+                raise excs.Error(
+                    f"Language code '{language}' is not supported by Whisper model '{model_id}'. "
+                    f"Try common codes like 'en', 'es', 'fr', 'de', 'it', 'pt', 'ru', 'ja', 'ko', 'zh'."
+                ) from None
+
+    elif 'wav2vec2' in model_id.lower():
+        from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+        model = _lookup_model(model_id, Wav2Vec2ForCTC.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, Wav2Vec2Processor.from_pretrained)
+
+    elif 'speech_to_text' in model_id.lower() or 's2t' in model_id.lower():
+        # Use the existing speech2text function for these models
+        from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+
+        model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
+
+    else:
+        # Generic fallback using Auto classes
+        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+
+        try:
+            model = _lookup_model(model_id, AutoModelForSpeechSeq2Seq.from_pretrained, device=device)
+            processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+        except Exception:
+            # Fallback to CTC models
+            from transformers import AutoModelForCTC
+
+            model = _lookup_model(model_id, AutoModelForCTC.from_pretrained, device=device)
+            processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+
+    # Get model's expected sampling rate - following speech2text pattern
+    model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
+
+    # Load and preprocess audio - following speech2text pattern
+    waveform, sampling_rate = torchaudio.load(audio)
+
+    # Resample if necessary
+    if sampling_rate != model_sampling_rate:
+        waveform = torchaudio.transforms.Resample(sampling_rate, model_sampling_rate)(waveform)
+
+    # Convert to mono if stereo
+    if waveform.dim() == 2:
+        waveform = torch.mean(waveform, dim=0)
+    assert waveform.dim() == 1
+
+    with torch.no_grad():
+        # Process audio with the model
+        inputs = processor(waveform, sampling_rate=model_sampling_rate, return_tensors='pt')
+
+        # Handle different model types for generation
+        if 'whisper' in model_id.lower():
+            # Whisper-specific generation
+            generate_kwargs = {}
+            if language is not None:
+                generate_kwargs['language'] = language
+            if return_timestamps:
+                generate_kwargs['return_timestamps'] = 'word' if return_timestamps else None
+
+            generated_ids = model.generate(**inputs.to(device), **generate_kwargs)
+            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        elif hasattr(model, 'generate'):
+            # Seq2Seq models (Speech2Text, etc.)
+            generated_ids = model.generate(**inputs.to(device))
+            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        else:
+            # CTC models (Wav2Vec2, etc.)
+            logits = model(**inputs.to(device)).logits
+            predicted_ids = torch.argmax(logits, dim=-1)
+            transcription = processor.batch_decode(predicted_ids)[0]
+
+    return transcription.strip()
+
+
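The branch structure above lets one UDF cover both generation-based ASR models (Whisper, Speech2Text) and CTC models (Wav2Vec2): the former decode via `model.generate()`, the latter via an argmax over logits. A usage sketch with an illustrative table (names are hypothetical, not from this diff):

```python
import pixeltable as pxt
from pixeltable.functions.huggingface import automatic_speech_recognition  # assumed module path

# Hypothetical table with one audio clip per row.
tbl = pxt.create_table('asr_demo', {'clip': pxt.Audio})

# Whisper checkpoint: routed through the generate()-based branch.
tbl.add_computed_column(
    whisper_text=automatic_speech_recognition(tbl.clip, model_id='openai/whisper-tiny.en')
)

# Wav2Vec2 checkpoint: routed through the CTC branch (argmax over logits).
tbl.add_computed_column(
    ctc_text=automatic_speech_recognition(tbl.clip, model_id='facebook/wav2vec2-base-960h')
)
```
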
+@pxt.udf
+def image_to_video(
+    image: PIL.Image.Image,
+    *,
+    model_id: str,
+    num_frames: int = 25,
+    fps: int = 6,
+    seed: int | None = None,
+    model_kwargs: dict[str, Any] | None = None,
+) -> pxt.Video:
+    """
+    Generates videos from input images using a pretrained image-to-video model.
+    `model_id` should be a reference to a pretrained
+    [image-to-video model](https://huggingface.co/models?pipeline_tag=image-to-video).
+
+    __Requirements:__
+
+    - `pip install torch transformers diffusers accelerate`
+
+    Args:
+        image: The input image to animate into a video.
+        model_id: The pretrained image-to-video model to use.
+        num_frames: Number of video frames to generate.
+        fps: Frames per second for the output video.
+        seed: Random seed for reproducibility.
+        model_kwargs: Additional keyword arguments to pass to the model, such as `num_inference_steps`,
+            `motion_bucket_id`, or `guidance_scale`.
+
+    Returns:
+        The generated video file.
+
+    Examples:
+        Add a computed column that creates videos from images:
+
+        >>> tbl.add_computed_column(video=image_to_video(
+        ...     tbl.input_image,
+        ...     model_id='stabilityai/stable-video-diffusion-img2vid-xt',
+        ...     num_frames=25,
+        ...     fps=7
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('diffusers')
+    env.Env.get().require_package('accelerate')
+    device = resolve_torch_device('auto', allow_mps=False)
+    import numpy as np
+    import torch
+    from diffusers import StableVideoDiffusionPipeline
+
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    # Parameter validation - following best practices pattern
+    if num_frames < 1:
+        raise excs.Error(f'num_frames must be at least 1, got {num_frames}')
+
+    if num_frames > 25:
+        raise excs.Error(f'num_frames cannot exceed 25 for most video diffusion models, got {num_frames}')
+
+    if fps < 1:
+        raise excs.Error(f'fps must be at least 1, got {fps}')
+
+    if fps > 60:
+        raise excs.Error(f'fps should not exceed 60 for reasonable video generation, got {fps}')
+
+    pipe = _lookup_model(
+        model_id,
+        lambda x: StableVideoDiffusionPipeline.from_pretrained(
+            x,
+            torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
+            variant='fp16' if device == 'cuda' else None,
+        ),
+        device=device,
+    )
+
+    try:
+        if device == 'cuda' and hasattr(pipe, 'enable_model_cpu_offload'):
+            pipe.enable_model_cpu_offload()
+        if hasattr(pipe, 'enable_memory_efficient_attention'):
+            pipe.enable_memory_efficient_attention()
+    except Exception:
+        pass  # Ignore optimization failures
+
+    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
+
+    # Ensure image is in RGB mode and proper size
+    processed_image = image.convert('RGB')
+    target_width, target_height = 512, 320
+    processed_image = processed_image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS)
+
+    # Generate video frames with proper error handling
+    with torch.no_grad():
+        result = pipe(image=processed_image, num_frames=num_frames, generator=generator, **model_kwargs)
+        frames = result.frames[0]
+
+    # Create output video file
+    output_path = str(TempStore.create_path(extension='.mp4'))
+
+    with av.open(output_path, mode='w') as container:
+        stream = container.add_stream('h264', rate=fps)
+        stream.width = target_width
+        stream.height = target_height
+        stream.pix_fmt = 'yuv420p'
+
+        # Set codec options for better compatibility
+        stream.codec_context.options = {'crf': '23', 'preset': 'medium'}
+
+        for frame_pil in frames:
+            # Convert PIL to numpy array
+            frame_array = np.array(frame_pil)
+            # Create av VideoFrame
+            av_frame = av.VideoFrame.from_ndarray(frame_array, format='rgb24')
+            # Encode and mux
+            for packet in stream.encode(av_frame):
+                container.mux(packet)
+
+        # Flush encoder
+        for packet in stream.encode():
+            container.mux(packet)
+
+    return output_path
+
+
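`image_to_video` resizes the input to 512x320, runs `StableVideoDiffusionPipeline`, and encodes the resulting frames to H.264 with PyAV; `num_frames` is validated to stay in [1, 25] and `fps` in [1, 60]. A usage sketch with hypothetical table and column names:

```python
import pixeltable as pxt
from pixeltable.functions.huggingface import image_to_video  # assumed module path

# Hypothetical table with a single still image per row.
tbl = pxt.create_table('img2vid_demo', {'still': pxt.Image})

# num_frames and fps must satisfy the UDF's validation checks shown above.
tbl.add_computed_column(
    clip=image_to_video(
        tbl.still,
        model_id='stabilityai/stable-video-diffusion-img2vid-xt',
        num_frames=14,
        fps=7,
        seed=0,
    )
)
```
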
+def _lookup_model(
+    model_id: str, create: Callable[..., T], device: str | None = None, pass_device_to_create: bool = False
 ) -> T:
     from torch import nn

@@ -526,12 +1518,13 @@ def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
     return _processor_cache[key]


-_model_cache: dict[tuple[str, Callable,
+_model_cache: dict[tuple[str, Callable, str | None], Any] = {}
+_speecht5_embeddings_dataset: list[Any] = []  # contains only the speecht5 embeddings loaded by text_to_speech()
 _processor_cache: dict[tuple[str, Callable], Any] = {}


 __all__ = local_public_names(__name__)


-def __dir__():
+def __dir__() -> list[str]:
     return __all__
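The last hunk changes the `_model_cache` key type to `tuple[str, Callable, str | None]`, i.e. the target device becomes part of the cache key, so the same checkpoint loaded onto different devices occupies separate cache entries. The body of `_lookup_model` is not shown in this hunk; the following is only a sketch of that keying scheme, with the loading and device-placement details assumed rather than taken from the source:

```python
from typing import Any, Callable, TypeVar

T = TypeVar('T')

# Cache keyed on (model id, loader callable, device), mirroring the declared key type.
_model_cache: dict[tuple[str, Callable, str | None], Any] = {}


def lookup_model(model_id: str, create: Callable[[str], T], device: str | None = None) -> T:
    # Including the device in the key keeps a 'cuda' copy and a 'cpu' copy
    # of the same checkpoint from overwriting each other.
    key = (model_id, create, device)
    if key not in _model_cache:
        model: Any = create(model_id)  # e.g. SomeModel.from_pretrained(model_id)
        if device is not None and hasattr(model, 'to'):
            model = model.to(device)
        if hasattr(model, 'eval'):
            model.eval()  # inference-only usage
        _model_cache[key] = model
    return _model_cache[key]
```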