pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. pixeltable/__init__.py +83 -19
  2. pixeltable/_query.py +1444 -0
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +7 -4
  5. pixeltable/catalog/catalog.py +2394 -119
  6. pixeltable/catalog/column.py +225 -104
  7. pixeltable/catalog/dir.py +38 -9
  8. pixeltable/catalog/globals.py +53 -34
  9. pixeltable/catalog/insertable_table.py +265 -115
  10. pixeltable/catalog/path.py +80 -17
  11. pixeltable/catalog/schema_object.py +28 -43
  12. pixeltable/catalog/table.py +1270 -677
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +1270 -751
  15. pixeltable/catalog/table_version_handle.py +109 -0
  16. pixeltable/catalog/table_version_path.py +137 -42
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +251 -134
  20. pixeltable/config.py +215 -0
  21. pixeltable/env.py +736 -285
  22. pixeltable/exceptions.py +26 -2
  23. pixeltable/exec/__init__.py +7 -2
  24. pixeltable/exec/aggregation_node.py +39 -21
  25. pixeltable/exec/cache_prefetch_node.py +87 -109
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +25 -28
  29. pixeltable/exec/data_row_batch.py +11 -46
  30. pixeltable/exec/exec_context.py +26 -11
  31. pixeltable/exec/exec_node.py +35 -27
  32. pixeltable/exec/expr_eval/__init__.py +3 -0
  33. pixeltable/exec/expr_eval/evaluators.py +365 -0
  34. pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
  35. pixeltable/exec/expr_eval/globals.py +200 -0
  36. pixeltable/exec/expr_eval/row_buffer.py +74 -0
  37. pixeltable/exec/expr_eval/schedulers.py +413 -0
  38. pixeltable/exec/globals.py +35 -0
  39. pixeltable/exec/in_memory_data_node.py +35 -27
  40. pixeltable/exec/object_store_save_node.py +293 -0
  41. pixeltable/exec/row_update_node.py +44 -29
  42. pixeltable/exec/sql_node.py +414 -115
  43. pixeltable/exprs/__init__.py +8 -5
  44. pixeltable/exprs/arithmetic_expr.py +79 -45
  45. pixeltable/exprs/array_slice.py +5 -5
  46. pixeltable/exprs/column_property_ref.py +40 -26
  47. pixeltable/exprs/column_ref.py +254 -61
  48. pixeltable/exprs/comparison.py +14 -9
  49. pixeltable/exprs/compound_predicate.py +9 -10
  50. pixeltable/exprs/data_row.py +213 -72
  51. pixeltable/exprs/expr.py +270 -104
  52. pixeltable/exprs/expr_dict.py +6 -5
  53. pixeltable/exprs/expr_set.py +20 -11
  54. pixeltable/exprs/function_call.py +383 -284
  55. pixeltable/exprs/globals.py +18 -5
  56. pixeltable/exprs/in_predicate.py +7 -7
  57. pixeltable/exprs/inline_expr.py +37 -37
  58. pixeltable/exprs/is_null.py +8 -4
  59. pixeltable/exprs/json_mapper.py +120 -54
  60. pixeltable/exprs/json_path.py +90 -60
  61. pixeltable/exprs/literal.py +61 -16
  62. pixeltable/exprs/method_ref.py +7 -6
  63. pixeltable/exprs/object_ref.py +19 -8
  64. pixeltable/exprs/row_builder.py +238 -75
  65. pixeltable/exprs/rowid_ref.py +53 -15
  66. pixeltable/exprs/similarity_expr.py +65 -50
  67. pixeltable/exprs/sql_element_cache.py +5 -5
  68. pixeltable/exprs/string_op.py +107 -0
  69. pixeltable/exprs/type_cast.py +25 -13
  70. pixeltable/exprs/variable.py +2 -2
  71. pixeltable/func/__init__.py +9 -5
  72. pixeltable/func/aggregate_function.py +197 -92
  73. pixeltable/func/callable_function.py +119 -35
  74. pixeltable/func/expr_template_function.py +101 -48
  75. pixeltable/func/function.py +375 -62
  76. pixeltable/func/function_registry.py +20 -19
  77. pixeltable/func/globals.py +6 -5
  78. pixeltable/func/mcp.py +74 -0
  79. pixeltable/func/query_template_function.py +151 -35
  80. pixeltable/func/signature.py +178 -49
  81. pixeltable/func/tools.py +164 -0
  82. pixeltable/func/udf.py +176 -53
  83. pixeltable/functions/__init__.py +44 -4
  84. pixeltable/functions/anthropic.py +226 -47
  85. pixeltable/functions/audio.py +148 -11
  86. pixeltable/functions/bedrock.py +137 -0
  87. pixeltable/functions/date.py +188 -0
  88. pixeltable/functions/deepseek.py +113 -0
  89. pixeltable/functions/document.py +81 -0
  90. pixeltable/functions/fal.py +76 -0
  91. pixeltable/functions/fireworks.py +72 -20
  92. pixeltable/functions/gemini.py +249 -0
  93. pixeltable/functions/globals.py +208 -53
  94. pixeltable/functions/groq.py +108 -0
  95. pixeltable/functions/huggingface.py +1088 -95
  96. pixeltable/functions/image.py +155 -84
  97. pixeltable/functions/json.py +8 -11
  98. pixeltable/functions/llama_cpp.py +31 -19
  99. pixeltable/functions/math.py +169 -0
  100. pixeltable/functions/mistralai.py +50 -75
  101. pixeltable/functions/net.py +70 -0
  102. pixeltable/functions/ollama.py +29 -36
  103. pixeltable/functions/openai.py +548 -160
  104. pixeltable/functions/openrouter.py +143 -0
  105. pixeltable/functions/replicate.py +15 -14
  106. pixeltable/functions/reve.py +250 -0
  107. pixeltable/functions/string.py +310 -85
  108. pixeltable/functions/timestamp.py +37 -19
  109. pixeltable/functions/together.py +77 -120
  110. pixeltable/functions/twelvelabs.py +188 -0
  111. pixeltable/functions/util.py +7 -2
  112. pixeltable/functions/uuid.py +30 -0
  113. pixeltable/functions/video.py +1528 -117
  114. pixeltable/functions/vision.py +26 -26
  115. pixeltable/functions/voyageai.py +289 -0
  116. pixeltable/functions/whisper.py +19 -10
  117. pixeltable/functions/whisperx.py +179 -0
  118. pixeltable/functions/yolox.py +112 -0
  119. pixeltable/globals.py +716 -236
  120. pixeltable/index/__init__.py +3 -1
  121. pixeltable/index/base.py +17 -21
  122. pixeltable/index/btree.py +32 -22
  123. pixeltable/index/embedding_index.py +155 -92
  124. pixeltable/io/__init__.py +12 -7
  125. pixeltable/io/datarows.py +140 -0
  126. pixeltable/io/external_store.py +83 -125
  127. pixeltable/io/fiftyone.py +24 -33
  128. pixeltable/io/globals.py +47 -182
  129. pixeltable/io/hf_datasets.py +96 -127
  130. pixeltable/io/label_studio.py +171 -156
  131. pixeltable/io/lancedb.py +3 -0
  132. pixeltable/io/pandas.py +136 -115
  133. pixeltable/io/parquet.py +40 -153
  134. pixeltable/io/table_data_conduit.py +702 -0
  135. pixeltable/io/utils.py +100 -0
  136. pixeltable/iterators/__init__.py +8 -4
  137. pixeltable/iterators/audio.py +207 -0
  138. pixeltable/iterators/base.py +9 -3
  139. pixeltable/iterators/document.py +144 -87
  140. pixeltable/iterators/image.py +17 -38
  141. pixeltable/iterators/string.py +15 -12
  142. pixeltable/iterators/video.py +523 -127
  143. pixeltable/metadata/__init__.py +33 -8
  144. pixeltable/metadata/converters/convert_10.py +2 -3
  145. pixeltable/metadata/converters/convert_13.py +2 -2
  146. pixeltable/metadata/converters/convert_15.py +15 -11
  147. pixeltable/metadata/converters/convert_16.py +4 -5
  148. pixeltable/metadata/converters/convert_17.py +4 -5
  149. pixeltable/metadata/converters/convert_18.py +4 -6
  150. pixeltable/metadata/converters/convert_19.py +6 -9
  151. pixeltable/metadata/converters/convert_20.py +3 -6
  152. pixeltable/metadata/converters/convert_21.py +6 -8
  153. pixeltable/metadata/converters/convert_22.py +3 -2
  154. pixeltable/metadata/converters/convert_23.py +33 -0
  155. pixeltable/metadata/converters/convert_24.py +55 -0
  156. pixeltable/metadata/converters/convert_25.py +19 -0
  157. pixeltable/metadata/converters/convert_26.py +23 -0
  158. pixeltable/metadata/converters/convert_27.py +29 -0
  159. pixeltable/metadata/converters/convert_28.py +13 -0
  160. pixeltable/metadata/converters/convert_29.py +110 -0
  161. pixeltable/metadata/converters/convert_30.py +63 -0
  162. pixeltable/metadata/converters/convert_31.py +11 -0
  163. pixeltable/metadata/converters/convert_32.py +15 -0
  164. pixeltable/metadata/converters/convert_33.py +17 -0
  165. pixeltable/metadata/converters/convert_34.py +21 -0
  166. pixeltable/metadata/converters/convert_35.py +9 -0
  167. pixeltable/metadata/converters/convert_36.py +38 -0
  168. pixeltable/metadata/converters/convert_37.py +15 -0
  169. pixeltable/metadata/converters/convert_38.py +39 -0
  170. pixeltable/metadata/converters/convert_39.py +124 -0
  171. pixeltable/metadata/converters/convert_40.py +73 -0
  172. pixeltable/metadata/converters/convert_41.py +12 -0
  173. pixeltable/metadata/converters/convert_42.py +9 -0
  174. pixeltable/metadata/converters/convert_43.py +44 -0
  175. pixeltable/metadata/converters/util.py +44 -18
  176. pixeltable/metadata/notes.py +21 -0
  177. pixeltable/metadata/schema.py +185 -42
  178. pixeltable/metadata/utils.py +74 -0
  179. pixeltable/mypy/__init__.py +3 -0
  180. pixeltable/mypy/mypy_plugin.py +123 -0
  181. pixeltable/plan.py +616 -225
  182. pixeltable/share/__init__.py +3 -0
  183. pixeltable/share/packager.py +797 -0
  184. pixeltable/share/protocol/__init__.py +33 -0
  185. pixeltable/share/protocol/common.py +165 -0
  186. pixeltable/share/protocol/operation_types.py +33 -0
  187. pixeltable/share/protocol/replica.py +119 -0
  188. pixeltable/share/publish.py +349 -0
  189. pixeltable/store.py +398 -232
  190. pixeltable/type_system.py +730 -267
  191. pixeltable/utils/__init__.py +40 -0
  192. pixeltable/utils/arrow.py +201 -29
  193. pixeltable/utils/av.py +298 -0
  194. pixeltable/utils/azure_store.py +346 -0
  195. pixeltable/utils/coco.py +26 -27
  196. pixeltable/utils/code.py +4 -4
  197. pixeltable/utils/console_output.py +46 -0
  198. pixeltable/utils/coroutine.py +24 -0
  199. pixeltable/utils/dbms.py +92 -0
  200. pixeltable/utils/description_helper.py +11 -12
  201. pixeltable/utils/documents.py +60 -61
  202. pixeltable/utils/exception_handler.py +36 -0
  203. pixeltable/utils/filecache.py +38 -22
  204. pixeltable/utils/formatter.py +88 -51
  205. pixeltable/utils/gcs_store.py +295 -0
  206. pixeltable/utils/http.py +133 -0
  207. pixeltable/utils/http_server.py +14 -13
  208. pixeltable/utils/iceberg.py +13 -0
  209. pixeltable/utils/image.py +17 -0
  210. pixeltable/utils/lancedb.py +90 -0
  211. pixeltable/utils/local_store.py +322 -0
  212. pixeltable/utils/misc.py +5 -0
  213. pixeltable/utils/object_stores.py +573 -0
  214. pixeltable/utils/pydantic.py +60 -0
  215. pixeltable/utils/pytorch.py +20 -20
  216. pixeltable/utils/s3_store.py +527 -0
  217. pixeltable/utils/sql.py +32 -5
  218. pixeltable/utils/system.py +30 -0
  219. pixeltable/utils/transactional_directory.py +4 -3
  220. pixeltable-0.5.7.dist-info/METADATA +579 -0
  221. pixeltable-0.5.7.dist-info/RECORD +227 -0
  222. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  223. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  224. pixeltable/__version__.py +0 -3
  225. pixeltable/catalog/named_function.py +0 -36
  226. pixeltable/catalog/path_dict.py +0 -141
  227. pixeltable/dataframe.py +0 -894
  228. pixeltable/exec/expr_eval_node.py +0 -232
  229. pixeltable/ext/__init__.py +0 -14
  230. pixeltable/ext/functions/__init__.py +0 -8
  231. pixeltable/ext/functions/whisperx.py +0 -77
  232. pixeltable/ext/functions/yolox.py +0 -157
  233. pixeltable/tool/create_test_db_dump.py +0 -311
  234. pixeltable/tool/create_test_video.py +0 -81
  235. pixeltable/tool/doc_plugins/griffe.py +0 -50
  236. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  237. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  238. pixeltable/tool/embed_udf.py +0 -9
  239. pixeltable/tool/mypy_plugin.py +0 -55
  240. pixeltable/utils/media_store.py +0 -76
  241. pixeltable/utils/s3.py +0 -16
  242. pixeltable-0.2.26.dist-info/METADATA +0 -400
  243. pixeltable-0.2.26.dist-info/RECORD +0 -156
  244. pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
  245. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
@@ -1,5 +1,5 @@
1
1
  """
2
- Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs) for Computer Vision.
2
+ Pixeltable UDFs for Computer Vision.
3
3
 
4
4
  Example:
5
5
  ```python
@@ -14,7 +14,7 @@ t.select(pxtv.draw_bounding_boxes(t.img, boxes=t.boxes, label=t.labels)).collect
14
14
  import colorsys
15
15
  import hashlib
16
16
  from collections import defaultdict
17
- from typing import Any, Optional, Union
17
+ from typing import Any
18
18
 
19
19
  import numpy as np
20
20
  import PIL.Image
@@ -205,7 +205,9 @@ def eval_detections(
205
205
  pred_filter = pred_classes_arr == class_idx
206
206
  gt_filter = gt_classes_arr == class_idx
207
207
  class_pred_scores = pred_scores_arr[pred_filter]
208
- tp, fp = __calculate_image_tpfp(pred_bboxes_arr[pred_filter], class_pred_scores, gt_bboxes_arr[gt_filter], min_iou)
208
+ tp, fp = __calculate_image_tpfp(
209
+ pred_bboxes_arr[pred_filter], class_pred_scores, gt_bboxes_arr[gt_filter], min_iou
210
+ )
209
211
  ordered_class_pred_scores = -np.sort(-class_pred_scores)
210
212
  result.append(
211
213
  {
@@ -220,7 +222,7 @@ def eval_detections(
220
222
  return result
221
223
 
222
224
 
223
- @pxt.uda(update_types=[pxt.JsonType()], value_type=pxt.JsonType(), allows_std_agg=True, allows_window=False)
225
+ @pxt.uda
224
226
  class mean_ap(pxt.Aggregator):
225
227
  """
226
228
  Calculates the mean average precision (mAP) over
@@ -235,7 +237,8 @@ class mean_ap(pxt.Aggregator):
235
237
 
236
238
  - A `dict[int, float]` mapping each label class to an average precision (AP) value for that class.
237
239
  """
238
- def __init__(self):
240
+
241
+ def __init__(self) -> None:
239
242
  self.class_tpfp: dict[int, list[dict]] = defaultdict(list)
240
243
 
241
244
  def update(self, eval_dicts: list[dict]) -> None:
@@ -247,7 +250,6 @@ class mean_ap(pxt.Aggregator):
247
250
  eps = np.finfo(np.float32).eps
248
251
  result: dict[int, float] = {}
249
252
  for class_idx, tpfp in self.class_tpfp.items():
250
- a1 = [x['tp'] for x in tpfp]
251
253
  tp = np.concatenate([x['tp'] for x in tpfp], axis=0)
252
254
  fp = np.concatenate([x['fp'] for x in tpfp], axis=0)
253
255
  num_gts = np.sum([x['num_gts'] for x in tpfp])
@@ -282,22 +284,22 @@ def __create_label_colors(labels: list[Any]) -> dict[Any, str]:
282
284
  label_hash = int(hashlib.md5(str(label).encode()).hexdigest(), 16)
283
285
  hue = (label_hash % 360) / 360.0
284
286
  rgb = colorsys.hsv_to_rgb(hue, 0.7, 0.95)
285
- hex_color = '#{:02x}{:02x}{:02x}'.format(int(rgb[0]*255), int(rgb[1]*255), int(rgb[2]*255))
287
+ hex_color = '#{:02x}{:02x}{:02x}'.format(int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
286
288
  result[label] = hex_color
287
289
  return result
288
290
 
289
291
 
290
292
  @pxt.udf
291
293
  def draw_bounding_boxes(
292
- img: PIL.Image.Image,
293
- boxes: list[list[int]],
294
- labels: Optional[list[Any]] = None,
295
- color: Optional[str] = None,
296
- box_colors: Optional[list[str]] = None,
297
- fill: bool = False,
298
- width: int = 1,
299
- font: Optional[str] = None,
300
- font_size: Optional[int] = None,
294
+ img: PIL.Image.Image,
295
+ boxes: list[list[int]],
296
+ labels: list[Any] | None = None,
297
+ color: str | None = None,
298
+ box_colors: list[str] | None = None,
299
+ fill: bool = False,
300
+ width: int = 1,
301
+ font: str | None = None,
302
+ font_size: int | None = None,
301
303
  ) -> PIL.Image.Image:
302
304
  """
303
305
  Draws bounding boxes on the given image.
@@ -338,21 +340,19 @@ def draw_bounding_boxes(
338
340
  elif len(labels) != num_boxes:
339
341
  raise ValueError('Number of boxes and labels must match')
340
342
 
341
- DEFAULT_COLOR = 'white'
342
343
  if box_colors is not None:
343
344
  if len(box_colors) != num_boxes:
344
345
  raise ValueError('Number of boxes and box colors must match')
346
+ elif color is not None:
347
+ box_colors = [color] * num_boxes
345
348
  else:
346
- if color is not None:
347
- box_colors = [color] * num_boxes
348
- else:
349
- label_colors = __create_label_colors(labels)
350
- box_colors = [label_colors[label] for label in labels]
349
+ label_colors = __create_label_colors(labels)
350
+ box_colors = [label_colors[label] for label in labels]
351
351
 
352
352
  from PIL import ImageColor, ImageDraw, ImageFont
353
353
 
354
354
  # set default font if not provided
355
- txt_font: Union[ImageFont.ImageFont, ImageFont.FreeTypeFont] = (
355
+ txt_font: ImageFont.ImageFont | ImageFont.FreeTypeFont = (
356
356
  ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size or 10)
357
357
  )
358
358
 
@@ -366,13 +366,13 @@ def draw_bounding_boxes(
366
366
 
367
367
  if fill:
368
368
  rgb_color = ImageColor.getrgb(color)
369
- fill_color = rgb_color + (100,) # semi-transparent
369
+ fill_color = (*rgb_color, 100) # semi-transparent
370
370
  draw.rectangle(bbox, outline=color, width=width, fill=fill_color) # type: ignore[arg-type]
371
371
  else:
372
372
  draw.rectangle(bbox, outline=color, width=width) # type: ignore[arg-type]
373
373
 
374
374
  # Now draw labels separately, so they are not obscured by the boxes
375
- for i, (bbox, label) in enumerate(zip(boxes, labels)):
375
+ for bbox, label in zip(boxes, labels):
376
376
  if label is not None:
377
377
  label_str = str(label)
378
378
  _, _, text_width, text_height = draw.textbbox((0, 0), label_str, font=txt_font)
@@ -394,5 +394,5 @@ def draw_bounding_boxes(
394
394
  __all__ = local_public_names(__name__)
395
395
 
396
396
 
397
- def __dir__():
397
+ def __dir__() -> list[str]:
398
398
  return __all__
@@ -0,0 +1,289 @@
1
+ """
2
+ Pixeltable UDFs
3
+ that wrap various endpoints from the Voyage AI API. In order to use them, you must
4
+ first `pip install voyageai` and configure your Voyage AI credentials, as described in
5
+ the [Working with Voyage AI](https://docs.pixeltable.com/notebooks/integrations/working-with-voyageai) tutorial.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any, Literal
9
+
10
+ import numpy as np
11
+ import PIL.Image
12
+
13
+ import pixeltable as pxt
14
+ from pixeltable import env, type_system as ts
15
+ from pixeltable.func import Batch
16
+ from pixeltable.utils.code import local_public_names
17
+
18
+ # Default embedding dimensions for Voyage AI models
19
+ _embedding_dimensions_cache: dict[str, int] = {
20
+ 'voyage-3-large': 1024,
21
+ 'voyage-3.5': 1024,
22
+ 'voyage-3.5-lite': 1024,
23
+ 'voyage-code-3': 1024,
24
+ 'voyage-finance-2': 1024,
25
+ 'voyage-law-2': 1024,
26
+ 'voyage-code-2': 1536,
27
+ 'voyage-3': 1024,
28
+ 'voyage-3-lite': 512,
29
+ 'voyage-multilingual-2': 1024,
30
+ 'voyage-large-2': 1536,
31
+ 'voyage-2': 1024,
32
+ }
33
+
34
+ if TYPE_CHECKING:
35
+ from voyageai import AsyncClient
36
+
37
+
38
+ @env.register_client('voyage')
39
+ def _(api_key: str) -> 'AsyncClient':
40
+ from voyageai import AsyncClient
41
+
42
+ return AsyncClient(api_key=api_key)
43
+
44
+
45
+ def _voyageai_client() -> 'AsyncClient':
46
+ return env.Env.get().get_client('voyage')
47
+
48
+
49
+ @pxt.udf(batch_size=128, resource_pool='request-rate:voyageai')
50
+ async def embeddings(
51
+ input: Batch[str],
52
+ *,
53
+ model: str,
54
+ input_type: Literal['query', 'document'] | None = None,
55
+ truncation: bool | None = None,
56
+ output_dimension: int | None = None,
57
+ output_dtype: Literal['float', 'int8', 'uint8', 'binary', 'ubinary'] | None = None,
58
+ ) -> Batch[pxt.Array[(None,), pxt.Float]]:
59
+ """
60
+ Creates an embedding vector representing the input text.
61
+
62
+ Equivalent to the Voyage AI `embeddings` API endpoint.
63
+ For additional details, see: <https://docs.voyageai.com/docs/embeddings>
64
+
65
+ Request throttling:
66
+ Applies the rate limit set in the config (section `voyageai`, key `rate_limit`). If no rate
67
+ limit is configured, uses a default of 600 RPM.
68
+
69
+ __Requirements:__
70
+
71
+ - `pip install voyageai`
72
+
73
+ Args:
74
+ input: The text to embed.
75
+ model: The model to use for the embedding. Recommended options: `voyage-3-large`, `voyage-3.5`,
76
+ `voyage-3.5-lite`, `voyage-code-3`, `voyage-finance-2`, `voyage-law-2`.
77
+ input_type: Type of the input text. Options: `None`, `query`, `document`.
78
+ When `input_type` is `None`, the embedding model directly converts the inputs into numerical vectors.
79
+ For retrieval/search purposes, we recommend setting this to `query` or `document` as appropriate.
80
+ truncation: Whether to truncate the input texts to fit within the context length. Defaults to `True`.
81
+ output_dimension: The number of dimensions for resulting output embeddings.
82
+ Most models only support a single default dimension. Models `voyage-3-large`, `voyage-3.5`,
83
+ `voyage-3.5-lite`, and `voyage-code-3` support: 256, 512, 1024 (default), and 2048.
84
+ output_dtype: The data type for the embeddings to be returned. Options: `float`, `int8`, `uint8`,
85
+ `binary`, `ubinary`. Only `float` is currently supported in Pixeltable.
86
+
87
+ Returns:
88
+ An array representing the application of the given embedding to `input`.
89
+
90
+ Examples:
91
+ Add a computed column that applies the model `voyage-3.5` to an existing
92
+ Pixeltable column `tbl.text` of the table `tbl`:
93
+
94
+ >>> tbl.add_computed_column(embed=embeddings(tbl.text, model='voyage-3.5', input_type='document'))
95
+
96
+ Add an embedding index to an existing column `text`, using the model `voyage-3.5`:
97
+
98
+ >>> tbl.add_embedding_index('text', string_embed=embeddings.using(model='voyage-3.5'))
99
+ """
100
+ cl = _voyageai_client()
101
+
102
+ # Build kwargs for the API call
103
+ kwargs: dict[str, Any] = {}
104
+ if input_type is not None:
105
+ kwargs['input_type'] = input_type
106
+ if truncation is not None:
107
+ kwargs['truncation'] = truncation
108
+ if output_dimension is not None:
109
+ kwargs['output_dimension'] = output_dimension
110
+ if output_dtype is not None:
111
+ kwargs['output_dtype'] = output_dtype
112
+
113
+ result = await cl.embed(texts=input, model=model, **kwargs)
114
+ # TODO: set output dtype correctly based on output_dtype parameter
115
+ return [np.array(emb, dtype=np.float64) for emb in result.embeddings]
116
+
117
+
118
+ @embeddings.conditional_return_type
119
+ def _(
120
+ model: str,
121
+ input_type: Literal['query', 'document'] | None = None,
122
+ truncation: bool | None = None,
123
+ output_dimension: int | None = None,
124
+ output_dtype: Literal['float', 'int8', 'uint8', 'binary', 'ubinary'] | None = None,
125
+ ) -> ts.ArrayType:
126
+ # If output_dimension is explicitly specified, use it
127
+ if output_dimension is not None:
128
+ return ts.ArrayType((output_dimension,), dtype=ts.FloatType(), nullable=False)
129
+ # Otherwise, look up the default for this model
130
+ dimensions = _embedding_dimensions_cache.get(model)
131
+ if dimensions is None:
132
+ return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
133
+ return ts.ArrayType((dimensions,), dtype=ts.FloatType(), nullable=False)
134
+
135
+
136
+ @pxt.udf(resource_pool='request-rate:voyageai')
137
+ async def rerank(
138
+ query: str, documents: list[str], *, model: str, top_k: int | None = None, truncation: bool = True
139
+ ) -> dict:
140
+ """
141
+ Reranks documents based on their relevance to a query.
142
+
143
+ Equivalent to the Voyage AI `rerank` API endpoint.
144
+ For additional details, see: <https://docs.voyageai.com/docs/reranker>
145
+
146
+ Request throttling:
147
+ Applies the rate limit set in the config (section `voyageai`, key `rate_limit`). If no rate
148
+ limit is configured, uses a default of 600 RPM.
149
+
150
+ __Requirements:__
151
+
152
+ - `pip install voyageai`
153
+
154
+ Args:
155
+ query: The query as a string.
156
+ documents: The documents to be reranked as a list of strings.
157
+ model: The model to use for reranking. Recommended options: `rerank-2.5`, `rerank-2.5-lite`.
158
+ top_k: The number of most relevant documents to return. If not specified, all documents
159
+ will be reranked and returned.
160
+ truncation: Whether to truncate the input to satisfy context length limits. Defaults to `True`.
161
+
162
+ Returns:
163
+ A dictionary containing:
164
+ - `results`: List of reranking results with `index`, `document`, and `relevance_score`
165
+ - `total_tokens`: The total number of tokens used
166
+
167
+ Examples:
168
+ Rerank similarity search results for better relevance. First, create a table with
169
+ an embedding index, then use a query function to retrieve candidates and rerank them:
170
+
171
+ >>> docs = pxt.create_table('docs', {'text': pxt.String})
172
+ >>> docs.add_computed_column(embed=embeddings(docs.text, model='voyage-3.5'))
173
+ >>> docs.add_embedding_index('text', embed=docs.embed)
174
+ >>>
175
+ >>> @pxt.query
176
+ ... def get_candidates(query_text: str):
177
+ ... sim = docs.text.similarity(query_text, embed=embeddings.using(model='voyage-3.5'))
178
+ ... return docs.order_by(sim, asc=False).limit(20).select(docs.text)
179
+ >>>
180
+ >>> queries = pxt.create_table('queries', {'query': pxt.String})
181
+ >>> queries.add_computed_column(candidates=get_candidates(queries.query))
182
+ >>> queries.add_computed_column(
183
+ ... reranked=rerank(queries.query, queries.candidates.text, model='rerank-2.5', top_k=5)
184
+ ... )
185
+ """
186
+ cl = _voyageai_client()
187
+
188
+ result = await cl.rerank(query=query, documents=documents, model=model, top_k=top_k, truncation=truncation)
189
+
190
+ # Convert the result to a dictionary format
191
+ return {
192
+ 'results': [
193
+ {'index': r.index, 'document': r.document, 'relevance_score': r.relevance_score} for r in result.results
194
+ ],
195
+ 'total_tokens': result.total_tokens,
196
+ }
197
+
198
+
199
+ @pxt.udf(batch_size=32, resource_pool='request-rate:voyageai')
200
+ async def multimodal_embed(
201
+ text: Batch[str],
202
+ *,
203
+ model: str = 'voyage-multimodal-3',
204
+ input_type: Literal['query', 'document'] | None = None,
205
+ truncation: bool = True,
206
+ ) -> Batch[pxt.Array[(1024,), pxt.Float]]:
207
+ """
208
+ Creates an embedding vector for text or images using Voyage AI's multimodal model.
209
+
210
+ Equivalent to the Voyage AI `multimodal_embed` API endpoint.
211
+ For additional details, see: <https://docs.voyageai.com/docs/multimodal-embeddings>
212
+
213
+ Request throttling:
214
+ Applies the rate limit set in the config (section `voyageai`, key `rate_limit`). If no rate
215
+ limit is configured, uses a default of 600 RPM.
216
+
217
+ __Requirements:__
218
+
219
+ - `pip install voyageai`
220
+
221
+ Args:
222
+ text: The text to embed.
223
+ model: The model to use. Currently only `voyage-multimodal-3` is supported.
224
+ input_type: Type of the input. Options: `None`, `query`, `document`.
225
+ For retrieval/search, set to `query` or `document` as appropriate.
226
+ truncation: Whether to truncate inputs to fit within context length. Defaults to `True`.
227
+
228
+ Returns:
229
+ An array of 1024 floats representing the embedding.
230
+
231
+ Examples:
232
+ Embed a text column `description`:
233
+
234
+ >>> tbl.add_computed_column(
235
+ ... embed=multimodal_embed(tbl.description, input_type='document')
236
+ ... )
237
+
238
+ Add an embedding index for column `description`:
239
+
240
+ >>> tbl.add_embedding_index('description', string_embed=multimodal_embed.using(model='voyage-multimodal-3'))
241
+
242
+ Embed an image column `img`:
243
+
244
+ >>> tbl.add_computed_column(embed=multimodal_embed(tbl.img, input_type='document'))
245
+ """
246
+ cl = _voyageai_client()
247
+
248
+ # Build inputs: each text becomes a single-element content list
249
+ inputs: list[list[str | PIL.Image.Image]] = [[t] for t in text]
250
+
251
+ kwargs: dict[str, Any] = {}
252
+ if input_type is not None:
253
+ kwargs['input_type'] = input_type
254
+ if truncation is not None:
255
+ kwargs['truncation'] = truncation
256
+
257
+ result = await cl.multimodal_embed(inputs=inputs, model=model, **kwargs)
258
+ return [np.array(emb, dtype=np.float64) for emb in result.embeddings]
259
+
260
+
261
+ @multimodal_embed.overload
262
+ async def _(
263
+ image: Batch[PIL.Image.Image],
264
+ *,
265
+ model: str = 'voyage-multimodal-3',
266
+ input_type: Literal['query', 'document'] | None = None,
267
+ truncation: bool = True,
268
+ ) -> Batch[pxt.Array[(1024,), pxt.Float]]:
269
+ """Image overload for multimodal_embed - embeds images using the multimodal model."""
270
+ cl = _voyageai_client()
271
+
272
+ # Build inputs: each image becomes a single-element content list
273
+ inputs: list[list[str | PIL.Image.Image]] = [[img] for img in image]
274
+
275
+ kwargs: dict[str, Any] = {}
276
+ if input_type is not None:
277
+ kwargs['input_type'] = input_type
278
+ if truncation is not None:
279
+ kwargs['truncation'] = truncation
280
+
281
+ result = await cl.multimodal_embed(inputs=inputs, model=model, **kwargs)
282
+ return [np.array(emb, dtype=np.float64) for emb in result.embeddings]
283
+
284
+
285
+ __all__ = local_public_names(__name__)
286
+
287
+
288
+ def __dir__() -> list[str]:
289
+ return __all__
@@ -1,34 +1,36 @@
1
1
  """
2
- Pixeltable [UDF](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
2
+ Pixeltable UDFs
3
3
  that wraps the OpenAI Whisper library.
4
4
 
5
5
  This UDF will cause Pixeltable to invoke the relevant model locally. In order to use it, you must
6
6
  first `pip install openai-whisper`.
7
7
  """
8
8
 
9
- from typing import TYPE_CHECKING, Optional
9
+ from typing import TYPE_CHECKING, Sequence
10
10
 
11
11
  import pixeltable as pxt
12
12
  from pixeltable.env import Env
13
+ from pixeltable.utils.code import local_public_names
13
14
 
14
15
  if TYPE_CHECKING:
15
16
  from whisper import Whisper # type: ignore[import-untyped]
16
17
 
18
+
17
19
  @pxt.udf
18
20
  def transcribe(
19
21
  audio: pxt.Audio,
20
22
  *,
21
23
  model: str,
22
- temperature: Optional[list[float]] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
23
- compression_ratio_threshold: Optional[float] = 2.4,
24
- logprob_threshold: Optional[float] = -1.0,
25
- no_speech_threshold: Optional[float] = 0.6,
24
+ temperature: Sequence[float] | None = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
25
+ compression_ratio_threshold: float | None = 2.4,
26
+ logprob_threshold: float | None = -1.0,
27
+ no_speech_threshold: float | None = 0.6,
26
28
  condition_on_previous_text: bool = True,
27
- initial_prompt: Optional[str] = None,
29
+ initial_prompt: str | None = None,
28
30
  word_timestamps: bool = False,
29
31
  prepend_punctuations: str = '"\'“¿([{-',
30
- append_punctuations: str = '"\'.。,,!!??::”)]}、',
31
- decode_options: Optional[dict] = None,
32
+ append_punctuations: str = '"\'.。,,!!??::”)]}、', # noqa: RUF001
33
+ decode_options: dict | None = None,
32
34
  ) -> dict:
33
35
  """
34
36
  Transcribe an audio file using Whisper.
@@ -52,7 +54,7 @@ def transcribe(
52
54
  Add a computed column that applies the model `base.en` to an existing Pixeltable column `tbl.audio`
53
55
  of the table `tbl`:
54
56
 
55
- >>> tbl['result'] = transcribe(tbl.audio, model='base.en')
57
+ >>> tbl.add_computed_column(result=transcribe(tbl.audio, model='base.en'))
56
58
  """
57
59
  Env.get().require_package('whisper')
58
60
  Env.get().require_package('torch')
@@ -89,3 +91,10 @@ def _lookup_model(model_id: str, device: str) -> 'Whisper':
89
91
 
90
92
 
91
93
  _model_cache: dict[tuple[str, str], 'Whisper'] = {}
94
+
95
+
96
+ __all__ = local_public_names(__name__)
97
+
98
+
99
+ def __dir__() -> list[str]:
100
+ return __all__
@@ -0,0 +1,179 @@
1
+ """WhisperX audio transcription and diarization functions."""
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ import numpy as np
6
+
7
+ import pixeltable as pxt
8
+ from pixeltable.config import Config
9
+ from pixeltable.functions.util import resolve_torch_device
10
+ from pixeltable.utils.code import local_public_names
11
+
12
+ if TYPE_CHECKING:
13
+ from transformers import Wav2Vec2Model
14
+ from whisperx.asr import FasterWhisperPipeline # type: ignore[import-untyped]
15
+ from whisperx.diarize import DiarizationPipeline # type: ignore[import-untyped]
16
+
17
+
@pxt.udf
def transcribe(
    audio: pxt.Audio,
    *,
    model: str,
    diarize: bool = False,
    compute_type: str | None = None,
    language: str | None = None,
    task: str | None = None,
    chunk_size: int | None = None,
    alignment_model_name: str | None = None,
    interpolate_method: str | None = None,
    return_char_alignments: bool | None = None,
    diarization_model_name: str | None = None,
    num_speakers: int | None = None,
    min_speakers: int | None = None,
    max_speakers: int | None = None,
) -> dict:
    """
    Transcribe an audio file using WhisperX.

    This UDF runs a transcription model _locally_ using the WhisperX library,
    equivalent to the WhisperX `transcribe` function, as described in the
    [WhisperX library documentation](https://github.com/m-bain/whisperX).

    If `diarize=True`, then speaker diarization will also be performed. Several of the UDF parameters are only valid if
    `diarize=True`, as documented in the parameters list below.

    __Requirements:__

    - `pip install whisperx`

    Args:
        audio: The audio file to transcribe.
        model: The name of the model to use for transcription.
        diarize: Whether to perform speaker diarization.
        compute_type: The compute type to use for the model (e.g., `'int8'`, `'float16'`). If `None`,
            defaults to `'float16'` on CUDA devices and `'int8'` otherwise.
        language: The language code for the transcription (e.g., `'en'` for English).
        task: The task to perform (e.g., `'transcribe'` or `'translate'`). Defaults to `'transcribe'`.
        chunk_size: The size of the audio chunks to process, in seconds. Defaults to `30`.
        alignment_model_name: The name of the alignment model to use. If `None`, uses the default model for the given
            language. Only valid if `diarize=True`.
        interpolate_method: The method to use for interpolation of the alignment results. If not specified, uses the
            WhisperX default (`'nearest'`). Only valid if `diarize=True`.
        return_char_alignments: Whether to return character-level alignments. Defaults to `False`.
            Only valid if `diarize=True`.
        diarization_model_name: The name of the diarization model to use. Defaults to
            `pyannote/speaker-diarization-3.1`. Only valid if `diarize=True`.
        num_speakers: The number of speakers to expect in the audio. By default, the model will try to detect the
            number of speakers. Only valid if `diarize=True`.
        min_speakers: If specified, the minimum number of speakers to expect in the audio.
            Only valid if `diarize=True`.
        max_speakers: If specified, the maximum number of speakers to expect in the audio.
            Only valid if `diarize=True`.

    Returns:
        A dictionary containing the audio transcription, diarization (if enabled), and various other metadata.
        The detected `language` is included whether or not diarization is enabled.

    Raises:
        pxt.Error: If a diarization-only parameter is set while `diarize=False`.

    Examples:
        Add a computed column that applies the model `tiny.en` to an existing Pixeltable column `tbl.audio`
        of the table `tbl`:

        >>> tbl.add_computed_column(result=transcribe(tbl.audio, model='tiny.en'))

        Add a computed column that applies the model `tiny.en` to an existing Pixeltable column `tbl.audio`
        of the table `tbl`, with speaker diarization enabled, expecting at least 2 speakers:

        >>> tbl.add_computed_column(
        ...     result=transcribe(
        ...         tbl.audio, model='tiny.en', diarize=True, min_speakers=2
        ...     )
        ... )
    """
    # Validate the diarization-only parameters explicitly (rather than via `locals()`), and do it before importing
    # WhisperX or loading any model, so that bad arguments fail fast and cheaply.
    if not diarize:
        diarize_only_params: dict[str, Any] = {
            'alignment_model_name': alignment_model_name,
            'interpolate_method': interpolate_method,
            'return_char_alignments': return_char_alignments,
            'diarization_model_name': diarization_model_name,
            'num_speakers': num_speakers,
            'min_speakers': min_speakers,
            'max_speakers': max_speakers,
        }
        for param, value in diarize_only_params.items():
            if value is not None:
                raise pxt.Error(f'`{param}` can only be set if `diarize=True`')

    import whisperx  # type: ignore[import-untyped]

    device = resolve_torch_device('auto', allow_mps=False)
    compute_type = compute_type or ('float16' if device == 'cuda' else 'int8')
    transcription_model = _lookup_transcription_model(model, device, compute_type)
    audio_array: np.ndarray = whisperx.load_audio(audio)
    transcribe_kwargs: dict[str, Any] = {'language': language, 'task': task}
    if chunk_size is not None:
        transcribe_kwargs['chunk_size'] = chunk_size
    result: dict[str, Any] = transcription_model.transcribe(audio_array, batch_size=16, **transcribe_kwargs)

    if diarize:
        # Alignment
        language_code = result['language']
        alignment_model, metadata = _lookup_alignment_model(language_code, device, alignment_model_name)
        align_kwargs: dict[str, Any] = {}
        if interpolate_method is not None:
            align_kwargs['interpolate_method'] = interpolate_method
        if return_char_alignments is not None:
            align_kwargs['return_char_alignments'] = return_char_alignments
        result = whisperx.align(result['segments'], alignment_model, metadata, audio_array, device, **align_kwargs)
        # `whisperx.align()` returns only segment data; restore the detected language so that the output carries a
        # top-level `language` key regardless of `diarize`.
        result['language'] = language_code

        # Diarization
        diarization_model = _lookup_diarization_model(device, diarization_model_name)
        diarization_segments = diarization_model(
            audio_array, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers
        )
        result = whisperx.assign_word_speakers(diarization_segments, result)

    return result
135
+
136
+
def _lookup_transcription_model(model: str, device: str, compute_type: str) -> 'FasterWhisperPipeline':
    """
    Return the WhisperX transcription pipeline for `model` on `device` with the given `compute_type`.

    Pipelines are memoized in the module-level `_model_cache`, keyed by `(model, device, compute_type)`; a pipeline
    is loaded with `whisperx.load_model()` only on a cache miss.
    """
    import whisperx

    cache_key = (model, device, compute_type)
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    pipeline = whisperx.load_model(model, device, compute_type=compute_type)
    _model_cache[cache_key] = pipeline
    return pipeline
145
+
146
+
def _lookup_alignment_model(language_code: str, device: str, model_name: str | None) -> tuple['Wav2Vec2Model', dict]:
    """
    Return the `(alignment_model, metadata)` pair for `language_code` on `device`.

    `model_name=None` lets WhisperX choose its default alignment model for the language. Results are memoized in
    the module-level `_alignment_model_cache`, keyed by `(language_code, device, model_name)`.
    """
    import whisperx

    cache_key = (language_code, device, model_name)
    if cache_key in _alignment_model_cache:
        return _alignment_model_cache[cache_key]

    align_model, align_metadata = whisperx.load_align_model(
        language_code=language_code, device=device, model_name=model_name
    )
    entry = (align_model, align_metadata)
    _alignment_model_cache[cache_key] = entry
    return entry
155
+
156
+
def _lookup_diarization_model(device: str, model_name: str | None) -> 'DiarizationPipeline':
    """
    Return a WhisperX `DiarizationPipeline` for `device`, optionally using a specific `model_name`.

    On a cache miss, the pipeline is built with the Hugging Face token read from Pixeltable's config
    (`auth_token` in the `hf` section). `model_name` is forwarded only when given, so WhisperX's default applies
    otherwise. Pipelines are memoized in the module-level `_diarization_model_cache`, keyed by `(device, model_name)`.
    """
    from whisperx.diarize import DiarizationPipeline

    cache_key = (device, model_name)
    if cache_key in _diarization_model_cache:
        return _diarization_model_cache[cache_key]

    hf_token = Config.get().get_string_value('auth_token', section='hf')
    optional_kwargs = {} if model_name is None else {'model_name': model_name}
    pipeline = DiarizationPipeline(device=device, use_auth_token=hf_token, **optional_kwargs)
    _diarization_model_cache[cache_key] = pipeline
    return pipeline
168
+
169
+
# Module-level memoization caches for loaded WhisperX models; entries are added by the `_lookup_*` helpers above
# and are never evicted.
# Transcription pipelines, keyed by (model, device, compute_type).
_model_cache: dict[tuple[str, str, str], 'FasterWhisperPipeline'] = {}
# (alignment model, metadata) pairs, keyed by (language_code, device, model_name); model_name may be None.
_alignment_model_cache: dict[tuple[str, str, str | None], tuple['Wav2Vec2Model', dict]] = {}
# Diarization pipelines, keyed by (device, model_name); model_name may be None.
_diarization_model_cache: dict[tuple[str, str | None], 'DiarizationPipeline'] = {}
173
+
174
+
# Public API of this module, as computed by `local_public_names()` (presumably the public names defined in this
# module itself rather than re-exported imports — see pixeltable.utils.code).
__all__ = local_public_names(__name__)


def __dir__() -> list[str]:
    """Module-level `__dir__` (PEP 562): restrict `dir()` on this module to the names in `__all__`."""
    return __all__