pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +83 -19
- pixeltable/_query.py +1444 -0
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +7 -4
- pixeltable/catalog/catalog.py +2394 -119
- pixeltable/catalog/column.py +225 -104
- pixeltable/catalog/dir.py +38 -9
- pixeltable/catalog/globals.py +53 -34
- pixeltable/catalog/insertable_table.py +265 -115
- pixeltable/catalog/path.py +80 -17
- pixeltable/catalog/schema_object.py +28 -43
- pixeltable/catalog/table.py +1270 -677
- pixeltable/catalog/table_metadata.py +103 -0
- pixeltable/catalog/table_version.py +1270 -751
- pixeltable/catalog/table_version_handle.py +109 -0
- pixeltable/catalog/table_version_path.py +137 -42
- pixeltable/catalog/tbl_ops.py +53 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +251 -134
- pixeltable/config.py +215 -0
- pixeltable/env.py +736 -285
- pixeltable/exceptions.py +26 -2
- pixeltable/exec/__init__.py +7 -2
- pixeltable/exec/aggregation_node.py +39 -21
- pixeltable/exec/cache_prefetch_node.py +87 -109
- pixeltable/exec/cell_materialization_node.py +268 -0
- pixeltable/exec/cell_reconstruction_node.py +168 -0
- pixeltable/exec/component_iteration_node.py +25 -28
- pixeltable/exec/data_row_batch.py +11 -46
- pixeltable/exec/exec_context.py +26 -11
- pixeltable/exec/exec_node.py +35 -27
- pixeltable/exec/expr_eval/__init__.py +3 -0
- pixeltable/exec/expr_eval/evaluators.py +365 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
- pixeltable/exec/expr_eval/globals.py +200 -0
- pixeltable/exec/expr_eval/row_buffer.py +74 -0
- pixeltable/exec/expr_eval/schedulers.py +413 -0
- pixeltable/exec/globals.py +35 -0
- pixeltable/exec/in_memory_data_node.py +35 -27
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +44 -29
- pixeltable/exec/sql_node.py +414 -115
- pixeltable/exprs/__init__.py +8 -5
- pixeltable/exprs/arithmetic_expr.py +79 -45
- pixeltable/exprs/array_slice.py +5 -5
- pixeltable/exprs/column_property_ref.py +40 -26
- pixeltable/exprs/column_ref.py +254 -61
- pixeltable/exprs/comparison.py +14 -9
- pixeltable/exprs/compound_predicate.py +9 -10
- pixeltable/exprs/data_row.py +213 -72
- pixeltable/exprs/expr.py +270 -104
- pixeltable/exprs/expr_dict.py +6 -5
- pixeltable/exprs/expr_set.py +20 -11
- pixeltable/exprs/function_call.py +383 -284
- pixeltable/exprs/globals.py +18 -5
- pixeltable/exprs/in_predicate.py +7 -7
- pixeltable/exprs/inline_expr.py +37 -37
- pixeltable/exprs/is_null.py +8 -4
- pixeltable/exprs/json_mapper.py +120 -54
- pixeltable/exprs/json_path.py +90 -60
- pixeltable/exprs/literal.py +61 -16
- pixeltable/exprs/method_ref.py +7 -6
- pixeltable/exprs/object_ref.py +19 -8
- pixeltable/exprs/row_builder.py +238 -75
- pixeltable/exprs/rowid_ref.py +53 -15
- pixeltable/exprs/similarity_expr.py +65 -50
- pixeltable/exprs/sql_element_cache.py +5 -5
- pixeltable/exprs/string_op.py +107 -0
- pixeltable/exprs/type_cast.py +25 -13
- pixeltable/exprs/variable.py +2 -2
- pixeltable/func/__init__.py +9 -5
- pixeltable/func/aggregate_function.py +197 -92
- pixeltable/func/callable_function.py +119 -35
- pixeltable/func/expr_template_function.py +101 -48
- pixeltable/func/function.py +375 -62
- pixeltable/func/function_registry.py +20 -19
- pixeltable/func/globals.py +6 -5
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +151 -35
- pixeltable/func/signature.py +178 -49
- pixeltable/func/tools.py +164 -0
- pixeltable/func/udf.py +176 -53
- pixeltable/functions/__init__.py +44 -4
- pixeltable/functions/anthropic.py +226 -47
- pixeltable/functions/audio.py +148 -11
- pixeltable/functions/bedrock.py +137 -0
- pixeltable/functions/date.py +188 -0
- pixeltable/functions/deepseek.py +113 -0
- pixeltable/functions/document.py +81 -0
- pixeltable/functions/fal.py +76 -0
- pixeltable/functions/fireworks.py +72 -20
- pixeltable/functions/gemini.py +249 -0
- pixeltable/functions/globals.py +208 -53
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1088 -95
- pixeltable/functions/image.py +155 -84
- pixeltable/functions/json.py +8 -11
- pixeltable/functions/llama_cpp.py +31 -19
- pixeltable/functions/math.py +169 -0
- pixeltable/functions/mistralai.py +50 -75
- pixeltable/functions/net.py +70 -0
- pixeltable/functions/ollama.py +29 -36
- pixeltable/functions/openai.py +548 -160
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +15 -14
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +310 -85
- pixeltable/functions/timestamp.py +37 -19
- pixeltable/functions/together.py +77 -120
- pixeltable/functions/twelvelabs.py +188 -0
- pixeltable/functions/util.py +7 -2
- pixeltable/functions/uuid.py +30 -0
- pixeltable/functions/video.py +1528 -117
- pixeltable/functions/vision.py +26 -26
- pixeltable/functions/voyageai.py +289 -0
- pixeltable/functions/whisper.py +19 -10
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/functions/yolox.py +112 -0
- pixeltable/globals.py +716 -236
- pixeltable/index/__init__.py +3 -1
- pixeltable/index/base.py +17 -21
- pixeltable/index/btree.py +32 -22
- pixeltable/index/embedding_index.py +155 -92
- pixeltable/io/__init__.py +12 -7
- pixeltable/io/datarows.py +140 -0
- pixeltable/io/external_store.py +83 -125
- pixeltable/io/fiftyone.py +24 -33
- pixeltable/io/globals.py +47 -182
- pixeltable/io/hf_datasets.py +96 -127
- pixeltable/io/label_studio.py +171 -156
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +136 -115
- pixeltable/io/parquet.py +40 -153
- pixeltable/io/table_data_conduit.py +702 -0
- pixeltable/io/utils.py +100 -0
- pixeltable/iterators/__init__.py +8 -4
- pixeltable/iterators/audio.py +207 -0
- pixeltable/iterators/base.py +9 -3
- pixeltable/iterators/document.py +144 -87
- pixeltable/iterators/image.py +17 -38
- pixeltable/iterators/string.py +15 -12
- pixeltable/iterators/video.py +523 -127
- pixeltable/metadata/__init__.py +33 -8
- pixeltable/metadata/converters/convert_10.py +2 -3
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_15.py +15 -11
- pixeltable/metadata/converters/convert_16.py +4 -5
- pixeltable/metadata/converters/convert_17.py +4 -5
- pixeltable/metadata/converters/convert_18.py +4 -6
- pixeltable/metadata/converters/convert_19.py +6 -9
- pixeltable/metadata/converters/convert_20.py +3 -6
- pixeltable/metadata/converters/convert_21.py +6 -8
- pixeltable/metadata/converters/convert_22.py +3 -2
- pixeltable/metadata/converters/convert_23.py +33 -0
- pixeltable/metadata/converters/convert_24.py +55 -0
- pixeltable/metadata/converters/convert_25.py +19 -0
- pixeltable/metadata/converters/convert_26.py +23 -0
- pixeltable/metadata/converters/convert_27.py +29 -0
- pixeltable/metadata/converters/convert_28.py +13 -0
- pixeltable/metadata/converters/convert_29.py +110 -0
- pixeltable/metadata/converters/convert_30.py +63 -0
- pixeltable/metadata/converters/convert_31.py +11 -0
- pixeltable/metadata/converters/convert_32.py +15 -0
- pixeltable/metadata/converters/convert_33.py +17 -0
- pixeltable/metadata/converters/convert_34.py +21 -0
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/convert_41.py +12 -0
- pixeltable/metadata/converters/convert_42.py +9 -0
- pixeltable/metadata/converters/convert_43.py +44 -0
- pixeltable/metadata/converters/util.py +44 -18
- pixeltable/metadata/notes.py +21 -0
- pixeltable/metadata/schema.py +185 -42
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +616 -225
- pixeltable/share/__init__.py +3 -0
- pixeltable/share/packager.py +797 -0
- pixeltable/share/protocol/__init__.py +33 -0
- pixeltable/share/protocol/common.py +165 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +119 -0
- pixeltable/share/publish.py +349 -0
- pixeltable/store.py +398 -232
- pixeltable/type_system.py +730 -267
- pixeltable/utils/__init__.py +40 -0
- pixeltable/utils/arrow.py +201 -29
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +346 -0
- pixeltable/utils/coco.py +26 -27
- pixeltable/utils/code.py +4 -4
- pixeltable/utils/console_output.py +46 -0
- pixeltable/utils/coroutine.py +24 -0
- pixeltable/utils/dbms.py +92 -0
- pixeltable/utils/description_helper.py +11 -12
- pixeltable/utils/documents.py +60 -61
- pixeltable/utils/exception_handler.py +36 -0
- pixeltable/utils/filecache.py +38 -22
- pixeltable/utils/formatter.py +88 -51
- pixeltable/utils/gcs_store.py +295 -0
- pixeltable/utils/http.py +133 -0
- pixeltable/utils/http_server.py +14 -13
- pixeltable/utils/iceberg.py +13 -0
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +90 -0
- pixeltable/utils/local_store.py +322 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +573 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +20 -20
- pixeltable/utils/s3_store.py +527 -0
- pixeltable/utils/sql.py +32 -5
- pixeltable/utils/system.py +30 -0
- pixeltable/utils/transactional_directory.py +4 -3
- pixeltable-0.5.7.dist-info/METADATA +579 -0
- pixeltable-0.5.7.dist-info/RECORD +227 -0
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
- pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/catalog/named_function.py +0 -36
- pixeltable/catalog/path_dict.py +0 -141
- pixeltable/dataframe.py +0 -894
- pixeltable/exec/expr_eval_node.py +0 -232
- pixeltable/ext/__init__.py +0 -14
- pixeltable/ext/functions/__init__.py +0 -8
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/ext/functions/yolox.py +0 -157
- pixeltable/tool/create_test_db_dump.py +0 -311
- pixeltable/tool/create_test_video.py +0 -81
- pixeltable/tool/doc_plugins/griffe.py +0 -50
- pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
- pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
- pixeltable/tool/embed_udf.py +0 -9
- pixeltable/tool/mypy_plugin.py +0 -55
- pixeltable/utils/media_store.py +0 -76
- pixeltable/utils/s3.py +0 -16
- pixeltable-0.2.26.dist-info/METADATA +0 -400
- pixeltable-0.2.26.dist-info/RECORD +0 -156
- pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/functions/vision.py
CHANGED
@@ -1,5 +1,5 @@
 """
-Pixeltable
+Pixeltable UDFs for Computer Vision.
 
 Example:
 ```python
@@ -14,7 +14,7 @@ t.select(pxtv.draw_bounding_boxes(t.img, boxes=t.boxes, label=t.labels)).collect
 import colorsys
 import hashlib
 from collections import defaultdict
-from typing import Any
+from typing import Any
 
 import numpy as np
 import PIL.Image
@@ -205,7 +205,9 @@ def eval_detections(
         pred_filter = pred_classes_arr == class_idx
         gt_filter = gt_classes_arr == class_idx
         class_pred_scores = pred_scores_arr[pred_filter]
-        tp, fp = __calculate_image_tpfp(
+        tp, fp = __calculate_image_tpfp(
+            pred_bboxes_arr[pred_filter], class_pred_scores, gt_bboxes_arr[gt_filter], min_iou
+        )
         ordered_class_pred_scores = -np.sort(-class_pred_scores)
         result.append(
             {
@@ -220,7 +222,7 @@ def eval_detections(
     return result
 
 
-@pxt.uda
+@pxt.uda
 class mean_ap(pxt.Aggregator):
     """
     Calculates the mean average precision (mAP) over
@@ -235,7 +237,8 @@ class mean_ap(pxt.Aggregator):
 
         - A `dict[int, float]` mapping each label class to an average precision (AP) value for that class.
     """
-
+
+    def __init__(self) -> None:
         self.class_tpfp: dict[int, list[dict]] = defaultdict(list)
 
     def update(self, eval_dicts: list[dict]) -> None:
@@ -247,7 +250,6 @@ class mean_ap(pxt.Aggregator):
         eps = np.finfo(np.float32).eps
         result: dict[int, float] = {}
         for class_idx, tpfp in self.class_tpfp.items():
-            a1 = [x['tp'] for x in tpfp]
             tp = np.concatenate([x['tp'] for x in tpfp], axis=0)
             fp = np.concatenate([x['fp'] for x in tpfp], axis=0)
             num_gts = np.sum([x['num_gts'] for x in tpfp])
@@ -282,22 +284,22 @@ def __create_label_colors(labels: list[Any]) -> dict[Any, str]:
         label_hash = int(hashlib.md5(str(label).encode()).hexdigest(), 16)
         hue = (label_hash % 360) / 360.0
         rgb = colorsys.hsv_to_rgb(hue, 0.7, 0.95)
-        hex_color = '#{:02x}{:02x}{:02x}'.format(int(rgb[0]*255), int(rgb[1]*255), int(rgb[2]*255))
+        hex_color = '#{:02x}{:02x}{:02x}'.format(int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
         result[label] = hex_color
     return result
 
 
 @pxt.udf
 def draw_bounding_boxes(
-
-
-
-
-
-
-
-
-
+    img: PIL.Image.Image,
+    boxes: list[list[int]],
+    labels: list[Any] | None = None,
+    color: str | None = None,
+    box_colors: list[str] | None = None,
+    fill: bool = False,
+    width: int = 1,
+    font: str | None = None,
+    font_size: int | None = None,
 ) -> PIL.Image.Image:
     """
     Draws bounding boxes on the given image.
@@ -338,21 +340,19 @@ def draw_bounding_boxes(
     elif len(labels) != num_boxes:
         raise ValueError('Number of boxes and labels must match')
 
-    DEFAULT_COLOR = 'white'
     if box_colors is not None:
         if len(box_colors) != num_boxes:
             raise ValueError('Number of boxes and box colors must match')
+    elif color is not None:
+        box_colors = [color] * num_boxes
     else:
-
-
-        else:
-            label_colors = __create_label_colors(labels)
-            box_colors = [label_colors[label] for label in labels]
+        label_colors = __create_label_colors(labels)
+        box_colors = [label_colors[label] for label in labels]
 
     from PIL import ImageColor, ImageDraw, ImageFont
 
     # set default font if not provided
-    txt_font:
+    txt_font: ImageFont.ImageFont | ImageFont.FreeTypeFont = (
         ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size or 10)
     )
 
@@ -366,13 +366,13 @@ def draw_bounding_boxes(
 
         if fill:
             rgb_color = ImageColor.getrgb(color)
-            fill_color = rgb_color
+            fill_color = (*rgb_color, 100)  # semi-transparent
             draw.rectangle(bbox, outline=color, width=width, fill=fill_color)  # type: ignore[arg-type]
         else:
             draw.rectangle(bbox, outline=color, width=width)  # type: ignore[arg-type]
 
     # Now draw labels separately, so they are not obscured by the boxes
-    for
+    for bbox, label in zip(boxes, labels):
         if label is not None:
             label_str = str(label)
             _, _, text_width, text_height = draw.textbbox((0, 0), label_str, font=txt_font)
@@ -394,5 +394,5 @@ def draw_bounding_boxes(
 __all__ = local_public_names(__name__)
 
 
-def __dir__():
+def __dir__() -> list[str]:
     return __all__
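Usage note (not part of the package diff): the rewritten `draw_bounding_boxes` signature above accepts per-box colors, a single shared `color`, or neither, in which case colors are derived deterministically from the labels via `__create_label_colors`. A minimal sketch of the new call, assuming a table `t` with `img`, `boxes`, and `labels` columns as in the module docstring (table and column names are illustrative):

```python
import pixeltable as pxt
import pixeltable.functions.vision as pxtv

t = pxt.get_table('detections')  # hypothetical table with img/boxes/labels columns

t.select(
    pxtv.draw_bounding_boxes(
        t.img,
        boxes=t.boxes,    # list of [x1, y1, x2, y2] boxes per row
        labels=t.labels,  # one label per box; colors derived from labels
        fill=True,        # new behavior: semi-transparent box interiors
        width=2,
    )
).collect()
```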
pixeltable/functions/voyageai.py
ADDED
@@ -0,0 +1,289 @@
+"""
+Pixeltable UDFs
+that wrap various endpoints from the Voyage AI API. In order to use them, you must
+first `pip install voyageai` and configure your Voyage AI credentials, as described in
+the [Working with Voyage AI](https://docs.pixeltable.com/notebooks/integrations/working-with-voyageai) tutorial.
+"""
+
+from typing import TYPE_CHECKING, Any, Literal
+
+import numpy as np
+import PIL.Image
+
+import pixeltable as pxt
+from pixeltable import env, type_system as ts
+from pixeltable.func import Batch
+from pixeltable.utils.code import local_public_names
+
+# Default embedding dimensions for Voyage AI models
+_embedding_dimensions_cache: dict[str, int] = {
+    'voyage-3-large': 1024,
+    'voyage-3.5': 1024,
+    'voyage-3.5-lite': 1024,
+    'voyage-code-3': 1024,
+    'voyage-finance-2': 1024,
+    'voyage-law-2': 1024,
+    'voyage-code-2': 1536,
+    'voyage-3': 1024,
+    'voyage-3-lite': 512,
+    'voyage-multilingual-2': 1024,
+    'voyage-large-2': 1536,
+    'voyage-2': 1024,
+}
+
+if TYPE_CHECKING:
+    from voyageai import AsyncClient
+
+
+@env.register_client('voyage')
+def _(api_key: str) -> 'AsyncClient':
+    from voyageai import AsyncClient
+
+    return AsyncClient(api_key=api_key)
+
+
+def _voyageai_client() -> 'AsyncClient':
+    return env.Env.get().get_client('voyage')
+
+
+@pxt.udf(batch_size=128, resource_pool='request-rate:voyageai')
+async def embeddings(
+    input: Batch[str],
+    *,
+    model: str,
+    input_type: Literal['query', 'document'] | None = None,
+    truncation: bool | None = None,
+    output_dimension: int | None = None,
+    output_dtype: Literal['float', 'int8', 'uint8', 'binary', 'ubinary'] | None = None,
+) -> Batch[pxt.Array[(None,), pxt.Float]]:
+    """
+    Creates an embedding vector representing the input text.
+
+    Equivalent to the Voyage AI `embeddings` API endpoint.
+    For additional details, see: <https://docs.voyageai.com/docs/embeddings>
+
+    Request throttling:
+        Applies the rate limit set in the config (section `voyageai`, key `rate_limit`). If no rate
+        limit is configured, uses a default of 600 RPM.
+
+    __Requirements:__
+
+    - `pip install voyageai`
+
+    Args:
+        input: The text to embed.
+        model: The model to use for the embedding. Recommended options: `voyage-3-large`, `voyage-3.5`,
+            `voyage-3.5-lite`, `voyage-code-3`, `voyage-finance-2`, `voyage-law-2`.
+        input_type: Type of the input text. Options: `None`, `query`, `document`.
+            When `input_type` is `None`, the embedding model directly converts the inputs into numerical vectors.
+            For retrieval/search purposes, we recommend setting this to `query` or `document` as appropriate.
+        truncation: Whether to truncate the input texts to fit within the context length. Defaults to `True`.
+        output_dimension: The number of dimensions for resulting output embeddings.
+            Most models only support a single default dimension. Models `voyage-3-large`, `voyage-3.5`,
+            `voyage-3.5-lite`, and `voyage-code-3` support: 256, 512, 1024 (default), and 2048.
+        output_dtype: The data type for the embeddings to be returned. Options: `float`, `int8`, `uint8`,
+            `binary`, `ubinary`. Only `float` is currently supported in Pixeltable.
+
+    Returns:
+        An array representing the application of the given embedding to `input`.
+
+    Examples:
+        Add a computed column that applies the model `voyage-3.5` to an existing
+        Pixeltable column `tbl.text` of the table `tbl`:
+
+        >>> tbl.add_computed_column(embed=embeddings(tbl.text, model='voyage-3.5', input_type='document'))
+
+        Add an embedding index to an existing column `text`, using the model `voyage-3.5`:
+
+        >>> tbl.add_embedding_index('text', string_embed=embeddings.using(model='voyage-3.5'))
+    """
+    cl = _voyageai_client()
+
+    # Build kwargs for the API call
+    kwargs: dict[str, Any] = {}
+    if input_type is not None:
+        kwargs['input_type'] = input_type
+    if truncation is not None:
+        kwargs['truncation'] = truncation
+    if output_dimension is not None:
+        kwargs['output_dimension'] = output_dimension
+    if output_dtype is not None:
+        kwargs['output_dtype'] = output_dtype
+
+    result = await cl.embed(texts=input, model=model, **kwargs)
+    # TODO: set output dtype correctly based on output_dtype parameter
+    return [np.array(emb, dtype=np.float64) for emb in result.embeddings]
+
+
+@embeddings.conditional_return_type
+def _(
+    model: str,
+    input_type: Literal['query', 'document'] | None = None,
+    truncation: bool | None = None,
+    output_dimension: int | None = None,
+    output_dtype: Literal['float', 'int8', 'uint8', 'binary', 'ubinary'] | None = None,
+) -> ts.ArrayType:
+    # If output_dimension is explicitly specified, use it
+    if output_dimension is not None:
+        return ts.ArrayType((output_dimension,), dtype=ts.FloatType(), nullable=False)
+    # Otherwise, look up the default for this model
+    dimensions = _embedding_dimensions_cache.get(model)
+    if dimensions is None:
+        return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
+    return ts.ArrayType((dimensions,), dtype=ts.FloatType(), nullable=False)
+
+
+@pxt.udf(resource_pool='request-rate:voyageai')
+async def rerank(
+    query: str, documents: list[str], *, model: str, top_k: int | None = None, truncation: bool = True
+) -> dict:
+    """
+    Reranks documents based on their relevance to a query.
+
+    Equivalent to the Voyage AI `rerank` API endpoint.
+    For additional details, see: <https://docs.voyageai.com/docs/reranker>
+
+    Request throttling:
+        Applies the rate limit set in the config (section `voyageai`, key `rate_limit`). If no rate
+        limit is configured, uses a default of 600 RPM.
+
+    __Requirements:__
+
+    - `pip install voyageai`
+
+    Args:
+        query: The query as a string.
+        documents: The documents to be reranked as a list of strings.
+        model: The model to use for reranking. Recommended options: `rerank-2.5`, `rerank-2.5-lite`.
+        top_k: The number of most relevant documents to return. If not specified, all documents
+            will be reranked and returned.
+        truncation: Whether to truncate the input to satisfy context length limits. Defaults to `True`.
+
+    Returns:
+        A dictionary containing:
+        - `results`: List of reranking results with `index`, `document`, and `relevance_score`
+        - `total_tokens`: The total number of tokens used
+
+    Examples:
+        Rerank similarity search results for better relevance. First, create a table with
+        an embedding index, then use a query function to retrieve candidates and rerank them:
+
+        >>> docs = pxt.create_table('docs', {'text': pxt.String})
+        >>> docs.add_computed_column(embed=embeddings(docs.text, model='voyage-3.5'))
+        >>> docs.add_embedding_index('text', embed=docs.embed)
+        >>>
+        >>> @pxt.query
+        ... def get_candidates(query_text: str):
+        ...     sim = docs.text.similarity(query_text, embed=embeddings.using(model='voyage-3.5'))
+        ...     return docs.order_by(sim, asc=False).limit(20).select(docs.text)
+        >>>
+        >>> queries = pxt.create_table('queries', {'query': pxt.String})
+        >>> queries.add_computed_column(candidates=get_candidates(queries.query))
+        >>> queries.add_computed_column(
+        ...     reranked=rerank(queries.query, queries.candidates.text, model='rerank-2.5', top_k=5)
+        ... )
+    """
+    cl = _voyageai_client()
+
+    result = await cl.rerank(query=query, documents=documents, model=model, top_k=top_k, truncation=truncation)
+
+    # Convert the result to a dictionary format
+    return {
+        'results': [
+            {'index': r.index, 'document': r.document, 'relevance_score': r.relevance_score} for r in result.results
+        ],
+        'total_tokens': result.total_tokens,
+    }
+
+
+@pxt.udf(batch_size=32, resource_pool='request-rate:voyageai')
+async def multimodal_embed(
+    text: Batch[str],
+    *,
+    model: str = 'voyage-multimodal-3',
+    input_type: Literal['query', 'document'] | None = None,
+    truncation: bool = True,
+) -> Batch[pxt.Array[(1024,), pxt.Float]]:
+    """
+    Creates an embedding vector for text or images using Voyage AI's multimodal model.
+
+    Equivalent to the Voyage AI `multimodal_embed` API endpoint.
+    For additional details, see: <https://docs.voyageai.com/docs/multimodal-embeddings>
+
+    Request throttling:
+        Applies the rate limit set in the config (section `voyageai`, key `rate_limit`). If no rate
+        limit is configured, uses a default of 600 RPM.
+
+    __Requirements:__
+
+    - `pip install voyageai`
+
+    Args:
+        text: The text to embed.
+        model: The model to use. Currently only `voyage-multimodal-3` is supported.
+        input_type: Type of the input. Options: `None`, `query`, `document`.
+            For retrieval/search, set to `query` or `document` as appropriate.
+        truncation: Whether to truncate inputs to fit within context length. Defaults to `True`.
+
+    Returns:
+        An array of 1024 floats representing the embedding.
+
+    Examples:
+        Embed a text column `description`:
+
+        >>> tbl.add_computed_column(
+        ...     embed=multimodal_embed(tbl.description, input_type='document')
+        ... )
+
+        Add an embedding index for column `description`:
+
+        >>> tbl.add_embedding_index('description', string_embed=multimodal_embed.using(model='voyage-multimodal-3'))
+
+        Embed an image column `img`:
+
+        >>> tbl.add_computed_column(embed=multimodal_embed(tbl.img, input_type='document'))
+    """
+    cl = _voyageai_client()
+
+    # Build inputs: each text becomes a single-element content list
+    inputs: list[list[str | PIL.Image.Image]] = [[t] for t in text]
+
+    kwargs: dict[str, Any] = {}
+    if input_type is not None:
+        kwargs['input_type'] = input_type
+    if truncation is not None:
+        kwargs['truncation'] = truncation
+
+    result = await cl.multimodal_embed(inputs=inputs, model=model, **kwargs)
+    return [np.array(emb, dtype=np.float64) for emb in result.embeddings]
+
+
+@multimodal_embed.overload
+async def _(
+    image: Batch[PIL.Image.Image],
+    *,
+    model: str = 'voyage-multimodal-3',
+    input_type: Literal['query', 'document'] | None = None,
+    truncation: bool = True,
+) -> Batch[pxt.Array[(1024,), pxt.Float]]:
+    """Image overload for multimodal_embed - embeds images using the multimodal model."""
+    cl = _voyageai_client()
+
+    # Build inputs: each image becomes a single-element content list
+    inputs: list[list[str | PIL.Image.Image]] = [[img] for img in image]
+
+    kwargs: dict[str, Any] = {}
+    if input_type is not None:
+        kwargs['input_type'] = input_type
+    if truncation is not None:
+        kwargs['truncation'] = truncation
+
+    result = await cl.multimodal_embed(inputs=inputs, model=model, **kwargs)
+    return [np.array(emb, dtype=np.float64) for emb in result.embeddings]
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__() -> list[str]:
+    return __all__
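Usage note (not part of the package diff): the `conditional_return_type` hook above means the column type of an `embeddings` call resolves to a fixed-length array whenever the dimensionality can be determined, either from an explicit `output_dimension` or from the per-model defaults table. A sketch, assuming Voyage AI credentials are already configured:

```python
import pixeltable as pxt
from pixeltable.functions.voyageai import embeddings

tbl = pxt.create_table('docs', {'text': pxt.String})

# Known model: column type resolves to a (1024,)-float array per the defaults table.
tbl.add_computed_column(embed=embeddings(tbl.text, model='voyage-3.5', input_type='document'))

# Explicit output_dimension overrides the per-model default: a (256,)-float array.
tbl.add_computed_column(
    embed_small=embeddings(tbl.text, model='voyage-3-large', output_dimension=256)
)
```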
pixeltable/functions/whisper.py
CHANGED
@@ -1,34 +1,36 @@
 """
-Pixeltable
+Pixeltable UDFs
 that wraps the OpenAI Whisper library.
 
 This UDF will cause Pixeltable to invoke the relevant model locally. In order to use it, you must
 first `pip install openai-whisper`.
 """
 
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Sequence
 
 import pixeltable as pxt
 from pixeltable.env import Env
+from pixeltable.utils.code import local_public_names
 
 if TYPE_CHECKING:
     from whisper import Whisper  # type: ignore[import-untyped]
 
+
 @pxt.udf
 def transcribe(
     audio: pxt.Audio,
     *,
     model: str,
-    temperature:
-    compression_ratio_threshold:
-    logprob_threshold:
-    no_speech_threshold:
+    temperature: Sequence[float] | None = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+    compression_ratio_threshold: float | None = 2.4,
+    logprob_threshold: float | None = -1.0,
+    no_speech_threshold: float | None = 0.6,
     condition_on_previous_text: bool = True,
-    initial_prompt:
+    initial_prompt: str | None = None,
     word_timestamps: bool = False,
     prepend_punctuations: str = '"\'“¿([{-',
-    append_punctuations: str = '"\'.。,,!!??::”)]}、',
-    decode_options:
+    append_punctuations: str = '"\'.。,,!!??::”)]}、',  # noqa: RUF001
+    decode_options: dict | None = None,
 ) -> dict:
     """
     Transcribe an audio file using Whisper.
@@ -52,7 +54,7 @@ def transcribe(
         Add a computed column that applies the model `base.en` to an existing Pixeltable column `tbl.audio`
         of the table `tbl`:
 
-        >>> tbl
+        >>> tbl.add_computed_column(result=transcribe(tbl.audio, model='base.en'))
     """
     Env.get().require_package('whisper')
     Env.get().require_package('torch')
@@ -89,3 +91,10 @@ def _lookup_model(model_id: str, device: str) -> 'Whisper':
 
 
 _model_cache: dict[tuple[str, str], 'Whisper'] = {}
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__() -> list[str]:
+    return __all__
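Usage note (not part of the package diff): the new parameter defaults above mirror the Whisper library's own (a temperature fallback schedule plus compression-ratio, log-probability, and no-speech thresholds). A sketch of reading fields out of the returned dict, assuming a table `tbl` with an `audio` column; `text` and `language` are keys of Whisper's standard result dict:

```python
import pixeltable as pxt
from pixeltable.functions.whisper import transcribe

tbl = pxt.get_table('recordings')  # hypothetical table with an `audio` column
tbl.add_computed_column(result=transcribe(tbl.audio, model='base.en'))

# The UDF returns Whisper's result dict; use JSON-path access to project fields.
tbl.select(tbl.result.text, tbl.result.language).collect()
```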
pixeltable/functions/whisperx.py
ADDED
@@ -0,0 +1,179 @@
+"""WhisperX audio transcription and diarization functions."""
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+import pixeltable as pxt
+from pixeltable.config import Config
+from pixeltable.functions.util import resolve_torch_device
+from pixeltable.utils.code import local_public_names
+
+if TYPE_CHECKING:
+    from transformers import Wav2Vec2Model
+    from whisperx.asr import FasterWhisperPipeline  # type: ignore[import-untyped]
+    from whisperx.diarize import DiarizationPipeline  # type: ignore[import-untyped]
+
+
+@pxt.udf
+def transcribe(
+    audio: pxt.Audio,
+    *,
+    model: str,
+    diarize: bool = False,
+    compute_type: str | None = None,
+    language: str | None = None,
+    task: str | None = None,
+    chunk_size: int | None = None,
+    alignment_model_name: str | None = None,
+    interpolate_method: str | None = None,
+    return_char_alignments: bool | None = None,
+    diarization_model_name: str | None = None,
+    num_speakers: int | None = None,
+    min_speakers: int | None = None,
+    max_speakers: int | None = None,
+) -> dict:
+    """
+    Transcribe an audio file using WhisperX.
+
+    This UDF runs a transcription model _locally_ using the WhisperX library,
+    equivalent to the WhisperX `transcribe` function, as described in the
+    [WhisperX library documentation](https://github.com/m-bain/whisperX).
+
+    If `diarize=True`, then speaker diarization will also be performed. Several of the UDF parameters are only valid if
+    `diarize=True`, as documented in the parameters list below.
+
+    __Requirements:__
+
+    - `pip install whisperx`
+
+    Args:
+        audio: The audio file to transcribe.
+        model: The name of the model to use for transcription.
+        diarize: Whether to perform speaker diarization.
+        compute_type: The compute type to use for the model (e.g., `'int8'`, `'float16'`). If `None`,
+            defaults to `'float16'` on CUDA devices and `'int8'` otherwise.
+        language: The language code for the transcription (e.g., `'en'` for English).
+        task: The task to perform (e.g., `'transcribe'` or `'translate'`). Defaults to `'transcribe'`.
+        chunk_size: The size of the audio chunks to process, in seconds. Defaults to `30`.
+        alignment_model_name: The name of the alignment model to use. If `None`, uses the default model for the given
+            language. Only valid if `diarize=True`.
+        interpolate_method: The method to use for interpolation of the alignment results. If not specified, uses the
+            WhisperX default (`'nearest'`). Only valid if `diarize=True`.
+        return_char_alignments: Whether to return character-level alignments. Defaults to `False`.
+            Only valid if `diarize=True`.
+        diarization_model_name: The name of the diarization model to use. Defaults to
+            `pyannote/speaker-diarization-3.1`. Only valid if `diarize=True`.
+        num_speakers: The number of speakers to expect in the audio. By default, the model will try to detect the
+            number of speakers. Only valid if `diarize=True`.
+        min_speakers: If specified, the minimum number of speakers to expect in the audio.
+            Only valid if `diarize=True`.
+        max_speakers: If specified, the maximum number of speakers to expect in the audio.
+            Only valid if `diarize=True`.
+
+    Returns:
+        A dictionary containing the audio transcription, diarization (if enabled), and various other metadata.
+
+    Examples:
+        Add a computed column that applies the model `tiny.en` to an existing Pixeltable column `tbl.audio`
+        of the table `tbl`:
+
+        >>> tbl.add_computed_column(result=transcribe(tbl.audio, model='tiny.en'))
+
+        Add a computed column that applies the model `tiny.en` to an existing Pixeltable column `tbl.audio`
+        of the table `tbl`, with speaker diarization enabled, expecting at least 2 speakers:
+
+        >>> tbl.add_computed_column(
+        ...     result=transcribe(
+        ...         tbl.audio, model='tiny.en', diarize=True, min_speakers=2
+        ...     )
+        ... )
+    """
+    import whisperx  # type: ignore[import-untyped]
+
+    if not diarize:
+        args = locals()
+        for param in (
+            'alignment_model_name',
+            'interpolate_method',
+            'return_char_alignments',
+            'diarization_model_name',
+            'num_speakers',
+            'min_speakers',
+            'max_speakers',
+        ):
+            if args[param] is not None:
+                raise pxt.Error(f'`{param}` can only be set if `diarize=True`')
+
+    device = resolve_torch_device('auto', allow_mps=False)
+    compute_type = compute_type or ('float16' if device == 'cuda' else 'int8')
+    transcription_model = _lookup_transcription_model(model, device, compute_type)
+    audio_array: np.ndarray = whisperx.load_audio(audio)
+    kwargs: dict[str, Any] = {'language': language, 'task': task}
+    if chunk_size is not None:
+        kwargs['chunk_size'] = chunk_size
+    result: dict[str, Any] = transcription_model.transcribe(audio_array, batch_size=16, **kwargs)
+
+    if diarize:
+        # Alignment
+        alignment_model, metadata = _lookup_alignment_model(result['language'], device, alignment_model_name)
+        kwargs = {}
+        if interpolate_method is not None:
+            kwargs['interpolate_method'] = interpolate_method
+        if return_char_alignments is not None:
+            kwargs['return_char_alignments'] = return_char_alignments
+        result = whisperx.align(result['segments'], alignment_model, metadata, audio_array, device, **kwargs)
+
+        # Diarization
+        diarization_model = _lookup_diarization_model(device, diarization_model_name)
+        diarization_segments = diarization_model(
+            audio_array, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers
+        )
+        result = whisperx.assign_word_speakers(diarization_segments, result)
+
+    return result
+
+
+def _lookup_transcription_model(model: str, device: str, compute_type: str) -> 'FasterWhisperPipeline':
+    import whisperx
+
+    key = (model, device, compute_type)
+    if key not in _model_cache:
+        transcription_model = whisperx.load_model(model, device, compute_type=compute_type)
+        _model_cache[key] = transcription_model
+    return _model_cache[key]
+
+
+def _lookup_alignment_model(language_code: str, device: str, model_name: str | None) -> tuple['Wav2Vec2Model', dict]:
+    import whisperx
+
+    key = (language_code, device, model_name)
+    if key not in _alignment_model_cache:
+        model, metadata = whisperx.load_align_model(language_code=language_code, device=device, model_name=model_name)
+        _alignment_model_cache[key] = (model, metadata)
+    return _alignment_model_cache[key]
+
+
+def _lookup_diarization_model(device: str, model_name: str | None) -> 'DiarizationPipeline':
+    from whisperx.diarize import DiarizationPipeline
+
+    key = (device, model_name)
+    if key not in _diarization_model_cache:
+        auth_token = Config.get().get_string_value('auth_token', section='hf')
+        kwargs: dict[str, Any] = {'device': device, 'use_auth_token': auth_token}
+        if model_name is not None:
+            kwargs['model_name'] = model_name
+        _diarization_model_cache[key] = DiarizationPipeline(**kwargs)
+    return _diarization_model_cache[key]
+
+
+_model_cache: dict[tuple[str, str, str], 'FasterWhisperPipeline'] = {}
+_alignment_model_cache: dict[tuple[str, str, str | None], tuple['Wav2Vec2Model', dict]] = {}
+_diarization_model_cache: dict[tuple[str, str | None], 'DiarizationPipeline'] = {}
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__() -> list[str]:
+    return __all__
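Usage note (not part of the package diff): with `diarize=True`, the result passes through `whisperx.assign_word_speakers`, which attaches speaker labels to the aligned segments (field names follow the WhisperX result format, not anything defined in this package). A sketch, assuming a table `tbl` with an `audio` column and a configured `hf` auth token for the pyannote diarization model:

```python
import pixeltable as pxt
from pixeltable.functions.whisperx import transcribe

tbl = pxt.get_table('meetings')  # hypothetical table with an `audio` column
tbl.add_computed_column(
    result=transcribe(tbl.audio, model='tiny.en', diarize=True, min_speakers=2, max_speakers=4)
)

# Each aligned segment should carry a `speaker` label after diarization.
tbl.select(tbl.result.segments).collect()
```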