pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. pixeltable/__init__.py +83 -19
  2. pixeltable/_query.py +1444 -0
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +7 -4
  5. pixeltable/catalog/catalog.py +2394 -119
  6. pixeltable/catalog/column.py +225 -104
  7. pixeltable/catalog/dir.py +38 -9
  8. pixeltable/catalog/globals.py +53 -34
  9. pixeltable/catalog/insertable_table.py +265 -115
  10. pixeltable/catalog/path.py +80 -17
  11. pixeltable/catalog/schema_object.py +28 -43
  12. pixeltable/catalog/table.py +1270 -677
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +1270 -751
  15. pixeltable/catalog/table_version_handle.py +109 -0
  16. pixeltable/catalog/table_version_path.py +137 -42
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +251 -134
  20. pixeltable/config.py +215 -0
  21. pixeltable/env.py +736 -285
  22. pixeltable/exceptions.py +26 -2
  23. pixeltable/exec/__init__.py +7 -2
  24. pixeltable/exec/aggregation_node.py +39 -21
  25. pixeltable/exec/cache_prefetch_node.py +87 -109
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +25 -28
  29. pixeltable/exec/data_row_batch.py +11 -46
  30. pixeltable/exec/exec_context.py +26 -11
  31. pixeltable/exec/exec_node.py +35 -27
  32. pixeltable/exec/expr_eval/__init__.py +3 -0
  33. pixeltable/exec/expr_eval/evaluators.py +365 -0
  34. pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
  35. pixeltable/exec/expr_eval/globals.py +200 -0
  36. pixeltable/exec/expr_eval/row_buffer.py +74 -0
  37. pixeltable/exec/expr_eval/schedulers.py +413 -0
  38. pixeltable/exec/globals.py +35 -0
  39. pixeltable/exec/in_memory_data_node.py +35 -27
  40. pixeltable/exec/object_store_save_node.py +293 -0
  41. pixeltable/exec/row_update_node.py +44 -29
  42. pixeltable/exec/sql_node.py +414 -115
  43. pixeltable/exprs/__init__.py +8 -5
  44. pixeltable/exprs/arithmetic_expr.py +79 -45
  45. pixeltable/exprs/array_slice.py +5 -5
  46. pixeltable/exprs/column_property_ref.py +40 -26
  47. pixeltable/exprs/column_ref.py +254 -61
  48. pixeltable/exprs/comparison.py +14 -9
  49. pixeltable/exprs/compound_predicate.py +9 -10
  50. pixeltable/exprs/data_row.py +213 -72
  51. pixeltable/exprs/expr.py +270 -104
  52. pixeltable/exprs/expr_dict.py +6 -5
  53. pixeltable/exprs/expr_set.py +20 -11
  54. pixeltable/exprs/function_call.py +383 -284
  55. pixeltable/exprs/globals.py +18 -5
  56. pixeltable/exprs/in_predicate.py +7 -7
  57. pixeltable/exprs/inline_expr.py +37 -37
  58. pixeltable/exprs/is_null.py +8 -4
  59. pixeltable/exprs/json_mapper.py +120 -54
  60. pixeltable/exprs/json_path.py +90 -60
  61. pixeltable/exprs/literal.py +61 -16
  62. pixeltable/exprs/method_ref.py +7 -6
  63. pixeltable/exprs/object_ref.py +19 -8
  64. pixeltable/exprs/row_builder.py +238 -75
  65. pixeltable/exprs/rowid_ref.py +53 -15
  66. pixeltable/exprs/similarity_expr.py +65 -50
  67. pixeltable/exprs/sql_element_cache.py +5 -5
  68. pixeltable/exprs/string_op.py +107 -0
  69. pixeltable/exprs/type_cast.py +25 -13
  70. pixeltable/exprs/variable.py +2 -2
  71. pixeltable/func/__init__.py +9 -5
  72. pixeltable/func/aggregate_function.py +197 -92
  73. pixeltable/func/callable_function.py +119 -35
  74. pixeltable/func/expr_template_function.py +101 -48
  75. pixeltable/func/function.py +375 -62
  76. pixeltable/func/function_registry.py +20 -19
  77. pixeltable/func/globals.py +6 -5
  78. pixeltable/func/mcp.py +74 -0
  79. pixeltable/func/query_template_function.py +151 -35
  80. pixeltable/func/signature.py +178 -49
  81. pixeltable/func/tools.py +164 -0
  82. pixeltable/func/udf.py +176 -53
  83. pixeltable/functions/__init__.py +44 -4
  84. pixeltable/functions/anthropic.py +226 -47
  85. pixeltable/functions/audio.py +148 -11
  86. pixeltable/functions/bedrock.py +137 -0
  87. pixeltable/functions/date.py +188 -0
  88. pixeltable/functions/deepseek.py +113 -0
  89. pixeltable/functions/document.py +81 -0
  90. pixeltable/functions/fal.py +76 -0
  91. pixeltable/functions/fireworks.py +72 -20
  92. pixeltable/functions/gemini.py +249 -0
  93. pixeltable/functions/globals.py +208 -53
  94. pixeltable/functions/groq.py +108 -0
  95. pixeltable/functions/huggingface.py +1088 -95
  96. pixeltable/functions/image.py +155 -84
  97. pixeltable/functions/json.py +8 -11
  98. pixeltable/functions/llama_cpp.py +31 -19
  99. pixeltable/functions/math.py +169 -0
  100. pixeltable/functions/mistralai.py +50 -75
  101. pixeltable/functions/net.py +70 -0
  102. pixeltable/functions/ollama.py +29 -36
  103. pixeltable/functions/openai.py +548 -160
  104. pixeltable/functions/openrouter.py +143 -0
  105. pixeltable/functions/replicate.py +15 -14
  106. pixeltable/functions/reve.py +250 -0
  107. pixeltable/functions/string.py +310 -85
  108. pixeltable/functions/timestamp.py +37 -19
  109. pixeltable/functions/together.py +77 -120
  110. pixeltable/functions/twelvelabs.py +188 -0
  111. pixeltable/functions/util.py +7 -2
  112. pixeltable/functions/uuid.py +30 -0
  113. pixeltable/functions/video.py +1528 -117
  114. pixeltable/functions/vision.py +26 -26
  115. pixeltable/functions/voyageai.py +289 -0
  116. pixeltable/functions/whisper.py +19 -10
  117. pixeltable/functions/whisperx.py +179 -0
  118. pixeltable/functions/yolox.py +112 -0
  119. pixeltable/globals.py +716 -236
  120. pixeltable/index/__init__.py +3 -1
  121. pixeltable/index/base.py +17 -21
  122. pixeltable/index/btree.py +32 -22
  123. pixeltable/index/embedding_index.py +155 -92
  124. pixeltable/io/__init__.py +12 -7
  125. pixeltable/io/datarows.py +140 -0
  126. pixeltable/io/external_store.py +83 -125
  127. pixeltable/io/fiftyone.py +24 -33
  128. pixeltable/io/globals.py +47 -182
  129. pixeltable/io/hf_datasets.py +96 -127
  130. pixeltable/io/label_studio.py +171 -156
  131. pixeltable/io/lancedb.py +3 -0
  132. pixeltable/io/pandas.py +136 -115
  133. pixeltable/io/parquet.py +40 -153
  134. pixeltable/io/table_data_conduit.py +702 -0
  135. pixeltable/io/utils.py +100 -0
  136. pixeltable/iterators/__init__.py +8 -4
  137. pixeltable/iterators/audio.py +207 -0
  138. pixeltable/iterators/base.py +9 -3
  139. pixeltable/iterators/document.py +144 -87
  140. pixeltable/iterators/image.py +17 -38
  141. pixeltable/iterators/string.py +15 -12
  142. pixeltable/iterators/video.py +523 -127
  143. pixeltable/metadata/__init__.py +33 -8
  144. pixeltable/metadata/converters/convert_10.py +2 -3
  145. pixeltable/metadata/converters/convert_13.py +2 -2
  146. pixeltable/metadata/converters/convert_15.py +15 -11
  147. pixeltable/metadata/converters/convert_16.py +4 -5
  148. pixeltable/metadata/converters/convert_17.py +4 -5
  149. pixeltable/metadata/converters/convert_18.py +4 -6
  150. pixeltable/metadata/converters/convert_19.py +6 -9
  151. pixeltable/metadata/converters/convert_20.py +3 -6
  152. pixeltable/metadata/converters/convert_21.py +6 -8
  153. pixeltable/metadata/converters/convert_22.py +3 -2
  154. pixeltable/metadata/converters/convert_23.py +33 -0
  155. pixeltable/metadata/converters/convert_24.py +55 -0
  156. pixeltable/metadata/converters/convert_25.py +19 -0
  157. pixeltable/metadata/converters/convert_26.py +23 -0
  158. pixeltable/metadata/converters/convert_27.py +29 -0
  159. pixeltable/metadata/converters/convert_28.py +13 -0
  160. pixeltable/metadata/converters/convert_29.py +110 -0
  161. pixeltable/metadata/converters/convert_30.py +63 -0
  162. pixeltable/metadata/converters/convert_31.py +11 -0
  163. pixeltable/metadata/converters/convert_32.py +15 -0
  164. pixeltable/metadata/converters/convert_33.py +17 -0
  165. pixeltable/metadata/converters/convert_34.py +21 -0
  166. pixeltable/metadata/converters/convert_35.py +9 -0
  167. pixeltable/metadata/converters/convert_36.py +38 -0
  168. pixeltable/metadata/converters/convert_37.py +15 -0
  169. pixeltable/metadata/converters/convert_38.py +39 -0
  170. pixeltable/metadata/converters/convert_39.py +124 -0
  171. pixeltable/metadata/converters/convert_40.py +73 -0
  172. pixeltable/metadata/converters/convert_41.py +12 -0
  173. pixeltable/metadata/converters/convert_42.py +9 -0
  174. pixeltable/metadata/converters/convert_43.py +44 -0
  175. pixeltable/metadata/converters/util.py +44 -18
  176. pixeltable/metadata/notes.py +21 -0
  177. pixeltable/metadata/schema.py +185 -42
  178. pixeltable/metadata/utils.py +74 -0
  179. pixeltable/mypy/__init__.py +3 -0
  180. pixeltable/mypy/mypy_plugin.py +123 -0
  181. pixeltable/plan.py +616 -225
  182. pixeltable/share/__init__.py +3 -0
  183. pixeltable/share/packager.py +797 -0
  184. pixeltable/share/protocol/__init__.py +33 -0
  185. pixeltable/share/protocol/common.py +165 -0
  186. pixeltable/share/protocol/operation_types.py +33 -0
  187. pixeltable/share/protocol/replica.py +119 -0
  188. pixeltable/share/publish.py +349 -0
  189. pixeltable/store.py +398 -232
  190. pixeltable/type_system.py +730 -267
  191. pixeltable/utils/__init__.py +40 -0
  192. pixeltable/utils/arrow.py +201 -29
  193. pixeltable/utils/av.py +298 -0
  194. pixeltable/utils/azure_store.py +346 -0
  195. pixeltable/utils/coco.py +26 -27
  196. pixeltable/utils/code.py +4 -4
  197. pixeltable/utils/console_output.py +46 -0
  198. pixeltable/utils/coroutine.py +24 -0
  199. pixeltable/utils/dbms.py +92 -0
  200. pixeltable/utils/description_helper.py +11 -12
  201. pixeltable/utils/documents.py +60 -61
  202. pixeltable/utils/exception_handler.py +36 -0
  203. pixeltable/utils/filecache.py +38 -22
  204. pixeltable/utils/formatter.py +88 -51
  205. pixeltable/utils/gcs_store.py +295 -0
  206. pixeltable/utils/http.py +133 -0
  207. pixeltable/utils/http_server.py +14 -13
  208. pixeltable/utils/iceberg.py +13 -0
  209. pixeltable/utils/image.py +17 -0
  210. pixeltable/utils/lancedb.py +90 -0
  211. pixeltable/utils/local_store.py +322 -0
  212. pixeltable/utils/misc.py +5 -0
  213. pixeltable/utils/object_stores.py +573 -0
  214. pixeltable/utils/pydantic.py +60 -0
  215. pixeltable/utils/pytorch.py +20 -20
  216. pixeltable/utils/s3_store.py +527 -0
  217. pixeltable/utils/sql.py +32 -5
  218. pixeltable/utils/system.py +30 -0
  219. pixeltable/utils/transactional_directory.py +4 -3
  220. pixeltable-0.5.7.dist-info/METADATA +579 -0
  221. pixeltable-0.5.7.dist-info/RECORD +227 -0
  222. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  223. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  224. pixeltable/__version__.py +0 -3
  225. pixeltable/catalog/named_function.py +0 -36
  226. pixeltable/catalog/path_dict.py +0 -141
  227. pixeltable/dataframe.py +0 -894
  228. pixeltable/exec/expr_eval_node.py +0 -232
  229. pixeltable/ext/__init__.py +0 -14
  230. pixeltable/ext/functions/__init__.py +0 -8
  231. pixeltable/ext/functions/whisperx.py +0 -77
  232. pixeltable/ext/functions/yolox.py +0 -157
  233. pixeltable/tool/create_test_db_dump.py +0 -311
  234. pixeltable/tool/create_test_video.py +0 -81
  235. pixeltable/tool/doc_plugins/griffe.py +0 -50
  236. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  237. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  238. pixeltable/tool/embed_udf.py +0 -9
  239. pixeltable/tool/mypy_plugin.py +0 -55
  240. pixeltable/utils/media_store.py +0 -76
  241. pixeltable/utils/s3.py +0 -16
  242. pixeltable-0.2.26.dist-info/METADATA +0 -400
  243. pixeltable-0.2.26.dist-info/RECORD +0 -156
  244. pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
  245. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
@@ -1,5 +1,5 @@
  """
- Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
+ Pixeltable UDFs
  that wrap various models from the Hugging Face `transformers` package.
 
  These UDFs will cause Pixeltable to invoke the relevant models locally. In order to use them, you must
@@ -7,16 +7,22 @@ first `pip install transformers` (or in some cases, `sentence-transformers`, as
  UDFs).
  """
 
- from typing import Any, Callable, Optional, TypeVar
+ from typing import Any, Callable, Literal, TypeVar
 
+ import av
+ import numpy as np
  import PIL.Image
 
  import pixeltable as pxt
- import pixeltable.env as env
  import pixeltable.exceptions as excs
+ import pixeltable.type_system as ts
+ from pixeltable import env
  from pixeltable.func import Batch
  from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
  from pixeltable.utils.code import local_public_names
+ from pixeltable.utils.local_store import TempStore
+
+ T = TypeVar('T')
 
 
  @pxt.udf(batch_size=32)
@@ -46,12 +52,11 @@ def sentence_transformer(
  Add a computed column that applies the model `all-mpnet-base-2` to an existing Pixeltable column `tbl.sentence`
  of the table `tbl`:
 
- >>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')
+ >>> tbl.add_computed_column(result=sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2'))
  """
  env.Env.get().require_package('sentence_transformers')
  device = resolve_torch_device('auto')
- import torch
- from sentence_transformers import SentenceTransformer # type: ignore
+ from sentence_transformers import SentenceTransformer
 
  # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
  model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)
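
The two doctest lines in the hunk above capture the column-creation API change that recurs throughout this diff: item assignment (`tbl['result'] = ...`) is replaced by `Table.add_computed_column()`. A minimal before/after sketch, assuming a table `tbl` with a string column `sentence` (the column name `result` is illustrative):

>>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')                  # 0.2.x style
>>> tbl.add_computed_column(result=sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2'))  # 0.5.x style
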
@@ -62,21 +67,17 @@ def sentence_transformer(
 
 
  @sentence_transformer.conditional_return_type
- def _(model_id: str) -> pxt.ArrayType:
- try:
- from sentence_transformers import SentenceTransformer
+ def _(model_id: str) -> ts.ArrayType:
+ from sentence_transformers import SentenceTransformer
 
- model = _lookup_model(model_id, SentenceTransformer)
- return pxt.ArrayType((model.get_sentence_embedding_dimension(),), dtype=pxt.FloatType(), nullable=False)
- except ImportError:
- return pxt.ArrayType((None,), dtype=pxt.FloatType(), nullable=False)
+ model = _lookup_model(model_id, SentenceTransformer)
+ return ts.ArrayType((model.get_sentence_embedding_dimension(),), dtype=ts.FloatType(), nullable=False)
 
 
  @pxt.udf
  def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
  env.Env.get().require_package('sentence_transformers')
  device = resolve_torch_device('auto')
- import torch
  from sentence_transformers import SentenceTransformer
 
  # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
@@ -111,13 +112,12 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
  Add a computed column that applies the model `ms-marco-MiniLM-L-4-v2` to the sentences in
  columns `tbl.sentence1` and `tbl.sentence2`:
 
- >>> tbl['result'] = sentence_transformer(
- tbl.sentence1, tbl.sentence2, model_id='ms-marco-MiniLM-L-4-v2'
- )
+ >>> tbl.add_computed_column(result=sentence_transformer(
+ ... tbl.sentence1, tbl.sentence2, model_id='ms-marco-MiniLM-L-4-v2'
+ ... ))
  """
  env.Env.get().require_package('sentence_transformers')
  device = resolve_torch_device('auto')
- import torch
  from sentence_transformers import CrossEncoder
 
  # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
@@ -132,7 +132,6 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
  def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> list:
  env.Env.get().require_package('sentence_transformers')
  device = resolve_torch_device('auto')
- import torch
  from sentence_transformers import CrossEncoder
 
  # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
@@ -144,9 +143,9 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
 
 
  @pxt.udf(batch_size=32)
- def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
+ def clip(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
  """
- Computes a CLIP embedding for the specified text. `model_id` should be a reference to a pretrained
+ Computes a CLIP embedding for the specified text or image. `model_id` should be a reference to a pretrained
  [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
 
  __Requirements:__
@@ -164,12 +163,16 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), px
  Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
  Pixeltable column `tbl.text` of the table `tbl`:
 
- >>> tbl['result'] = clip_text(tbl.text, model_id='openai/clip-vit-base-patch32')
+ >>> tbl.add_computed_column(
+ ... result=clip(tbl.text, model_id='openai/clip-vit-base-patch32')
+ ... )
+
+ The same would work with an image column `tbl.image` in place of `tbl.text`.
  """
  env.Env.get().require_package('transformers')
  device = resolve_torch_device('auto')
  import torch
- from transformers import CLIPModel, CLIPProcessor # type: ignore
+ from transformers import CLIPModel, CLIPProcessor
 
  model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
  processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
@@ -181,29 +184,8 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), px
  return [embeddings[i] for i in range(embeddings.shape[0])]
 
 
- @pxt.udf(batch_size=32)
- def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
- """
- Computes a CLIP embedding for the specified image. `model_id` should be a reference to a pretrained
- [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
-
- __Requirements:__
-
- - `pip install torch transformers`
-
- Args:
- image: The image to embed.
- model_id: The pretrained model to use for the embedding.
-
- Returns:
- An array containing the output of the embedding model.
-
- Examples:
- Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
- Pixeltable column `image` of the table `tbl`:
-
- >>> tbl['result'] = clip_image(tbl.image, model_id='openai/clip-vit-base-patch32')
- """
+ @clip.overload
+ def _(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
  env.Env.get().require_package('transformers')
  device = resolve_torch_device('auto')
  import torch
@@ -219,25 +201,17 @@ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Arr
  return [embeddings[i] for i in range(embeddings.shape[0])]
 
 
- @clip_text.conditional_return_type
- @clip_image.conditional_return_type
- def _(model_id: str) -> pxt.ArrayType:
- try:
- from transformers import CLIPModel
+ @clip.conditional_return_type
+ def _(model_id: str) -> ts.ArrayType:
+ from transformers import CLIPModel
 
- model = _lookup_model(model_id, CLIPModel.from_pretrained)
- return pxt.ArrayType((model.config.projection_dim,), dtype=pxt.FloatType(), nullable=False)
- except ImportError:
- return pxt.ArrayType((None,), dtype=pxt.FloatType(), nullable=False)
+ model = _lookup_model(model_id, CLIPModel.from_pretrained)
+ return ts.ArrayType((model.config.projection_dim,), dtype=ts.FloatType(), nullable=False)
 
 
  @pxt.udf(batch_size=4)
  def detr_for_object_detection(
- image: Batch[PIL.Image.Image],
- *,
- model_id: str,
- threshold: float = 0.5,
- revision: str = 'no_timm',
+ image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5, revision: str = 'no_timm'
  ) -> Batch[dict]:
  """
  Computes DETR object detections for the specified image. `model_id` should be a reference to a pretrained
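
The hunks above fold `clip_text` and `clip_image` into a single `clip` UDF with an `@clip.overload` for images; per the updated docstring, the same call now accepts either a text or an image column. A minimal sketch, assuming a table `tbl` with columns `text` and `image` (the column names `text_emb` and `img_emb` are illustrative):

>>> tbl.add_computed_column(text_emb=clip(tbl.text, model_id='openai/clip-vit-base-patch32'))
>>> tbl.add_computed_column(img_emb=clip(tbl.image, model_id='openai/clip-vit-base-patch32'))
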
@@ -268,11 +242,11 @@ def detr_for_object_detection(
  Add a computed column that applies the model `facebook/detr-resnet-50` to an existing
  Pixeltable column `image` of the table `tbl`:
 
- >>> tbl['detections'] = detr_for_object_detection(
+ >>> tbl.add_computed_column(detections=detr_for_object_detection(
  ... tbl.image,
  ... model_id='facebook/detr-resnet-50',
  ... threshold=0.8
- ... )
+ ... ))
  """
  env.Env.get().require_package('transformers')
  device = resolve_torch_device('auto')
@@ -305,10 +279,7 @@ def detr_for_object_detection(
 
  @pxt.udf(batch_size=4)
  def vit_for_image_classification(
- image: Batch[PIL.Image.Image],
- *,
- model_id: str,
- top_k: int = 5
+ image: Batch[PIL.Image.Image], *, model_id: str, top_k: int = 5
  ) -> Batch[dict[str, Any]]:
  """
  Computes image classifications for the specified image using a Vision Transformer (ViT) model.
@@ -344,11 +315,11 @@ def vit_for_image_classification(
  Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
  Pixeltable column `image` of the table `tbl`, returning the 10 most likely classes for each image:
 
- >>> tbl['image_class'] = vit_for_image_classification(
+ >>> tbl.add_computed_column(image_class=vit_for_image_classification(
  ... tbl.image,
  ... model_id='google/vit-base-patch16-224',
  ... top_k=10
- ... )
+ ... ))
  """
  env.Env.get().require_package('transformers')
  device = resolve_torch_device('auto')
@@ -380,12 +351,7 @@ def vit_for_image_classification(
 
 
  @pxt.udf
- def speech2text_for_conditional_generation(
- audio: pxt.Audio,
- *,
- model_id: str,
- language: Optional[str] = None,
- ) -> str:
+ def speech2text_for_conditional_generation(audio: pxt.Audio, *, model_id: str, language: str | None = None) -> str:
  """
  Transcribes or translates speech to text using a Speech2Text model. `model_id` should be a reference to a
  pretrained [Speech2Text](https://huggingface.co/docs/transformers/en/model_doc/speech_to_text) model.
@@ -408,19 +374,19 @@ def speech2text_for_conditional_generation(
  Add a computed column that applies the model `facebook/s2t-small-librispeech-asr` to an existing
  Pixeltable column `audio` of the table `tbl`:
 
- >>> tbl['transcription'] = speech2text_for_conditional_generation(
+ >>> tbl.add_computed_column(transcription=speech2text_for_conditional_generation(
  ... tbl.audio,
  ... model_id='facebook/s2t-small-librispeech-asr'
- ... )
+ ... ))
 
  Add a computed column that applies the model `facebook/s2t-medium-mustc-multilingual-st` to an existing
  Pixeltable column `audio` of the table `tbl`, translating the audio to French:
 
- >>> tbl['translation'] = speech2text_for_conditional_generation(
+ >>> tbl.add_computed_column(translation=speech2text_for_conditional_generation(
  ... tbl.audio,
  ... model_id='facebook/s2t-medium-mustc-multilingual-st',
  ... language='fr'
- ... )
+ ... ))
  """
  env.Env.get().require_package('transformers')
  env.Env.get().require_package('torchaudio')
@@ -428,18 +394,21 @@ def speech2text_for_conditional_generation(
  device = resolve_torch_device('auto', allow_mps=False) # Doesn't seem to work on 'mps'; use 'cpu' instead
  import torch
  import torchaudio # type: ignore[import-untyped]
- from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+ from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor, Speech2TextTokenizer
 
  model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
  processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
+ tokenizer = processor.tokenizer
  assert isinstance(processor, Speech2TextProcessor)
+ assert isinstance(tokenizer, Speech2TextTokenizer)
 
- if language is not None and language not in processor.tokenizer.lang_code_to_id:
+ if language is not None and language not in tokenizer.lang_code_to_id:
  raise excs.Error(
  f"Language code '{language}' is not supported by the model '{model_id}'. "
- f"Supported languages are: {list(processor.tokenizer.lang_code_to_id.keys())}")
+ f'Supported languages are: {list(tokenizer.lang_code_to_id.keys())}'
+ )
 
- forced_bos_token_id: Optional[int] = None if language is None else processor.tokenizer.lang_code_to_id[language]
+ forced_bos_token_id: int | None = None if language is None else tokenizer.lang_code_to_id[language]
 
  # Get the model's sampling rate. Default to 16 kHz (the standard) if not in config
  model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
@@ -457,11 +426,7 @@ def speech2text_for_conditional_generation(
  assert waveform.dim() == 1
 
  with torch.no_grad():
- inputs = processor(
- waveform,
- sampling_rate=model_sampling_rate,
- return_tensors='pt'
- )
+ inputs = processor(waveform, sampling_rate=model_sampling_rate, return_tensors='pt')
  generated_ids = model.generate(**inputs.to(device), forced_bos_token_id=forced_bos_token_id).to('cpu')
 
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
@@ -484,7 +449,7 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
  Add a computed column that converts the output `tbl.detections` to COCO format, where `tbl.image`
  is the image for which detections were computed:
 
- >>> tbl['detections_coco'] = detr_to_coco(tbl.image, tbl.detections)
+ >>> tbl.add_computed_column(detections_coco=detr_to_coco(tbl.image, tbl.detections))
  """
  bboxes, labels = detr_info['boxes'], detr_info['labels']
  annotations = [
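
Taken together with `detr_for_object_detection` above, the `detr_to_coco` docstring describes a two-step pipeline: detect, then convert the detections to COCO format. A minimal sketch combining the two docstring examples, assuming a table `tbl` with an `image` column:

>>> tbl.add_computed_column(detections=detr_for_object_detection(
...     tbl.image, model_id='facebook/detr-resnet-50', threshold=0.8
... ))
>>> tbl.add_computed_column(detections_coco=detr_to_coco(tbl.image, tbl.detections))
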
@@ -494,14 +459,1041 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
  return {'image': {'width': image.width, 'height': image.height}, 'annotations': annotations}
 
 
- T = TypeVar('T')
+ @pxt.udf
+ def text_generation(text: str, *, model_id: str, model_kwargs: dict[str, Any] | None = None) -> str:
+ """
+ Generates text using a pretrained language model. `model_id` should be a reference to a pretrained
+ [text generation model](https://huggingface.co/models?pipeline_tag=text-generation).
 
+ __Requirements:__
 
- def _lookup_model(
+ - `pip install torch transformers`
+
+ Args:
+ text: The input text to continue/complete.
+ model_id: The pretrained model to use for text generation.
+ model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`,
+ `temperature`, etc. See the
+ [Hugging Face text_generation documentation](https://huggingface.co/docs/inference-providers/en/tasks/text-generation)
+ for details.
+
+ Returns:
+ The generated text completion.
+
+ Examples:
+ Add a computed column that generates text completions using the `Qwen/Qwen3-0.6B` model:
+
+ >>> tbl.add_computed_column(completion=text_generation(
+ ... tbl.prompt,
+ ... model_id='Qwen/Qwen3-0.6B',
+ ... model_kwargs={'temperature': 0.5, 'max_length': 150}
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ device = resolve_torch_device('auto')
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ model = _lookup_model(model_id, AutoModelForCausalLM.from_pretrained, device=device)
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ with torch.no_grad():
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
+ outputs = model.generate(**inputs.to(device), pad_token_id=tokenizer.eos_token_id, **model_kwargs)
+
+ input_length = len(inputs['input_ids'][0])
+ generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
+ return generated_text
+
+
+ @pxt.udf(batch_size=16)
+ def text_classification(text: Batch[str], *, model_id: str, top_k: int = 5) -> Batch[list[dict[str, Any]]]:
+ """
+ Classifies text using a pretrained classification model. `model_id` should be a reference to a pretrained
+ [text classification model](https://huggingface.co/models?pipeline_tag=text-classification)
+ such as BERT, RoBERTa, or DistilBERT.
+
+ __Requirements:__
+
+ - `pip install torch transformers`
+
+ Args:
+ text: The text to classify.
+ model_id: The pretrained model to use for classification.
+ top_k: The number of top predictions to return.
+
+ Returns:
+ A dictionary containing classification results with scores, labels, and label text.
+
+ Examples:
+ Add a computed column for sentiment analysis:
+
+ >>> tbl.add_computed_column(sentiment=text_classification(
+ ... tbl.review_text,
+ ... model_id='cardiffnlp/twitter-roberta-base-sentiment-latest'
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ device = resolve_torch_device('auto')
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ model = _lookup_model(model_id, AutoModelForSequenceClassification.from_pretrained, device=device)
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+ with torch.no_grad():
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
+ outputs = model(**inputs.to(device))
+ logits = outputs.logits
+
+ probs = torch.softmax(logits, dim=-1)
+ top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
+
+ results = []
+ for i in range(len(text)):
+ # Return as list of individual classification items for HuggingFace compatibility
+ classification_items = []
+ for k in range(top_k_probs.shape[1]):
+ classification_items.append(
+ {
+ 'label': top_k_indices[i, k].item(),
+ 'label_text': model.config.id2label[top_k_indices[i, k].item()],
+ 'score': top_k_probs[i, k].item(),
+ }
+ )
+ results.append(classification_items)
+
+ return results
+
+
+ @pxt.udf(batch_size=4)
+ def image_captioning(
+ image: Batch[PIL.Image.Image], *, model_id: str, model_kwargs: dict[str, Any] | None = None
+ ) -> Batch[str]:
+ """
+ Generates captions for images using a pretrained image captioning model. `model_id` should be a reference to a
+ pretrained [image-to-text model](https://huggingface.co/models?pipeline_tag=image-to-text) such as BLIP,
+ Git, or LLaVA.
+
+ __Requirements:__
+
+ - `pip install torch transformers`
+
+ Args:
+ image: The image to caption.
+ model_id: The pretrained model to use for captioning.
+ model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`.
+
+ Returns:
+ The generated caption text.
+
+ Examples:
+ Add a computed column `caption` to an existing table `tbl` that generates captions using the
+ `Salesforce/blip-image-captioning-base` model:
+
+ >>> tbl.add_computed_column(caption=image_captioning(
+ ... tbl.image,
+ ... model_id='Salesforce/blip-image-captioning-base',
+ ... model_kwargs={'max_length': 30}
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ device = resolve_torch_device('auto')
+ import torch
+ from transformers import AutoModelForVision2Seq, AutoProcessor
+
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ model = _lookup_model(model_id, AutoModelForVision2Seq.from_pretrained, device=device)
+ processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+ normalized_images = [normalize_image_mode(img) for img in image]
+
+ with torch.no_grad():
+ inputs = processor(images=normalized_images, return_tensors='pt')
+ outputs = model.generate(**inputs.to(device), **model_kwargs)
+
+ captions = processor.batch_decode(outputs, skip_special_tokens=True)
+ return captions
+
+
+ @pxt.udf(batch_size=8)
+ def summarization(text: Batch[str], *, model_id: str, model_kwargs: dict[str, Any] | None = None) -> Batch[str]:
+ """
+ Summarizes text using a pretrained summarization model. `model_id` should be a reference to a pretrained
+ [summarization model](https://huggingface.co/models?pipeline_tag=summarization) such as BART, T5, or Pegasus.
+
+ __Requirements:__
+
+ - `pip install torch transformers`
+
+ Args:
+ text: The text to summarize.
+ model_id: The pretrained model to use for summarization.
+ model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`.
+
+ Returns:
+ The generated summary text.
+
+ Examples:
+ Add a computed column that summarizes documents:
+
+ >>> tbl.add_computed_column(summary=text_summarization(
+ ... tbl.document_text,
+ ... model_id='facebook/bart-large-cnn',
+ ... max_length=100
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ device = resolve_torch_device('auto')
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ model = _lookup_model(model_id, AutoModelForSeq2SeqLM.from_pretrained, device=device)
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+ with torch.no_grad():
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
+ outputs = model.generate(**inputs.to(device), **model_kwargs)
+
+ return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+
+ @pxt.udf
+ def token_classification(
+ text: str, *, model_id: str, aggregation_strategy: Literal['simple', 'first', 'average', 'max'] = 'simple'
+ ) -> list[dict[str, Any]]:
+ """
+ Extracts named entities from text using a pretrained named entity recognition (NER) model.
+ `model_id` should be a reference to a pretrained
+ [token classification model](https://huggingface.co/models?pipeline_tag=token-classification) for NER.
+
+ __Requirements:__
+
+ - `pip install torch transformers`
+
+ Args:
+ text: The text to analyze for named entities.
+ model_id: The pretrained model to use.
+ aggregation_strategy: Method used to aggregate tokens.
+
+ Returns:
+ A list of dictionaries containing entity information (text, label, confidence, start, end).
+
+ Examples:
+ Add a computed column that extracts named entities:
+
+ >>> tbl.add_computed_column(entities=token_classification(
+ ... tbl.text,
+ ... model_id='dbmdz/bert-large-cased-finetuned-conll03-english'
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ device = resolve_torch_device('auto')
+ import torch
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
+
+ # Follow direct model loading pattern like other best practice functions
+ model = _lookup_model(model_id, AutoModelForTokenClassification.from_pretrained, device=device)
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+ # Validate aggregation strategy
+ valid_strategies = {'simple', 'first', 'average', 'max'}
+ if aggregation_strategy not in valid_strategies:
+ raise excs.Error(
+ f'Invalid aggregation_strategy {aggregation_strategy!r}. Must be one of: {", ".join(valid_strategies)}'
+ )
+
+ with torch.no_grad():
+ # Tokenize with special tokens and return offsets for entity extraction
+ inputs = tokenizer(
+ text,
+ return_tensors='pt',
+ truncation=True,
+ max_length=512,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
+
+ # Get model predictions
+ outputs = model(**{k: v.to(device) for k, v in inputs.items() if k != 'offset_mapping'})
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
+
+ # Get the predicted labels and confidence scores
+ predicted_token_classes = predictions.argmax(dim=-1).squeeze().tolist()
+ confidence_scores = predictions.max(dim=-1).values.squeeze().tolist()
+
+ # Handle single token case
+ if not isinstance(predicted_token_classes, list):
+ predicted_token_classes = [predicted_token_classes]
+ confidence_scores = [confidence_scores]
+
+ # Extract entities from predictions
+ entities = []
+ offset_mapping = inputs['offset_mapping'][0].tolist()
+
+ current_entity = None
+
+ for token_class, confidence, (start_offset, end_offset) in zip(
+ predicted_token_classes, confidence_scores, offset_mapping
+ ):
+ # Skip special tokens (offset is (0, 0))
+ if start_offset == 0 and end_offset == 0:
+ continue
+
+ label = model.config.id2label[token_class]
+
+ # Skip 'O' (outside) labels
+ if label == 'O':
+ if current_entity:
+ entities.append(current_entity)
+ current_entity = None
+ continue
+
+ # Parse BIO/BILOU tags
+ if label.startswith('B-') or (label.startswith('I-') and current_entity is None):
+ # Begin new entity
+ if current_entity:
+ entities.append(current_entity)
+
+ entity_type = label[2:] if label.startswith(('B-', 'I-')) else label
+ current_entity = {
+ 'word': text[start_offset:end_offset],
+ 'entity_group': entity_type,
+ 'score': float(confidence),
+ 'start': start_offset,
+ 'end': end_offset,
+ }
+
+ elif label.startswith('I-') and current_entity:
+ # Continue current entity
+ entity_type = label[2:]
+ if current_entity['entity_group'] == entity_type:
+ # Extend the current entity
+ current_entity['word'] = text[current_entity['start'] : end_offset]
+ current_entity['end'] = end_offset
+
+ # Update confidence based on aggregation strategy
+ if aggregation_strategy == 'average':
+ # Simple average (could be improved with token count weighting)
+ current_entity['score'] = (current_entity['score'] + float(confidence)) / 2
+ elif aggregation_strategy == 'max':
+ current_entity['score'] = max(current_entity['score'], float(confidence))
+ elif aggregation_strategy == 'first':
+ pass # Keep first confidence
+ # 'simple' uses the same logic as 'first'
+ else:
+ # Different entity type, start new entity
+ entities.append(current_entity)
+ current_entity = {
+ 'word': text[start_offset:end_offset],
+ 'entity_group': entity_type,
+ 'score': float(confidence),
+ 'start': start_offset,
+ 'end': end_offset,
+ }
+
+ # Don't forget the last entity
+ if current_entity:
+ entities.append(current_entity)
+
+ return entities
+
+
+ @pxt.udf
+ def question_answering(context: str, question: str, *, model_id: str) -> dict[str, Any]:
+ """
+ Answers questions based on provided context using a pretrained QA model. `model_id` should be a reference to a
+ pretrained [question answering model](https://huggingface.co/models?pipeline_tag=question-answering) such as
+ BERT or RoBERTa.
+
+ __Requirements:__
+
+ - `pip install torch transformers`
+
+ Args:
+ context: The context text containing the answer.
+ question: The question to answer.
+ model_id: The pretrained QA model to use.
+
+ Returns:
+ A dictionary containing the answer, confidence score, and start/end positions.
+
+ Examples:
+ Add a computed column that answers questions based on document context:
+
+ >>> tbl.add_computed_column(answer=question_answering(
+ ... tbl.document_text,
+ ... tbl.question,
+ ... model_id='deepset/roberta-base-squad2'
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ device = resolve_torch_device('auto')
+ import torch
+ from transformers import AutoModelForQuestionAnswering, AutoTokenizer
+
+ model = _lookup_model(model_id, AutoModelForQuestionAnswering.from_pretrained, device=device)
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+ with torch.no_grad():
+ # Tokenize the question and context
+ inputs = tokenizer.encode_plus(
+ question, context, add_special_tokens=True, return_tensors='pt', truncation=True, max_length=512
+ )
+
+ # Get model predictions
+ outputs = model(**inputs.to(device))
+ start_scores = outputs.start_logits
+ end_scores = outputs.end_logits
+
+ # Find the tokens with the highest start and end scores
+ start_idx = torch.argmax(start_scores)
+ end_idx = torch.argmax(end_scores)
+
+ # Ensure end_idx >= start_idx
+ end_idx = torch.max(end_idx, start_idx)
+
+ # Convert token positions to string
+ input_ids = inputs['input_ids'][0]
+
+ # Extract answer tokens
+ answer_tokens = input_ids[start_idx : end_idx + 1]
+ answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
+
+ # Calculate confidence score
+ start_probs = torch.softmax(start_scores, dim=1)
+ end_probs = torch.softmax(end_scores, dim=1)
+ confidence = float(start_probs[0][start_idx] * end_probs[0][end_idx])
+
+ return {'answer': answer.strip(), 'score': confidence, 'start': int(start_idx), 'end': int(end_idx)}
+
+
+ @pxt.udf(batch_size=8)
+ def translation(
+ text: Batch[str], *, model_id: str, src_lang: str | None = None, target_lang: str | None = None
+ ) -> Batch[str]:
+ """
+ Translates text using a pretrained translation model. `model_id` should be a reference to a pretrained
+ [translation model](https://huggingface.co/models?pipeline_tag=translation) such as MarianMT or T5.
+
+ __Requirements:__
+
+ - `pip install torch transformers sentencepiece`
+
+ Args:
+ text: The text to translate.
+ model_id: The pretrained translation model to use.
+ src_lang: Source language code (optional, can be inferred from model).
+ target_lang: Target language code (optional, can be inferred from model).
+
+ Returns:
+ The translated text.
+
+ Examples:
+ Add a computed column that translates text:
+
+ >>> tbl.add_computed_column(french_text=translation(
+ ... tbl.english_text,
+ ... model_id='Helsinki-NLP/opus-mt-en-fr',
+ ... src_lang='en',
+ ... target_lang='fr'
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ device = resolve_torch_device('auto')
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ model = _lookup_model(model_id, AutoModelForSeq2SeqLM.from_pretrained, device=device)
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+ lang_code_to_id: dict | None = getattr(tokenizer, 'lang_code_to_id', {})
+
+ # Language validation - following speech2text_for_conditional_generation pattern
+ if src_lang is not None and src_lang not in lang_code_to_id:
+ raise excs.Error(
+ f'Source language code {src_lang!r} is not supported by the model {model_id!r}. '
+ f'Supported languages are: {list(lang_code_to_id.keys())}'
+ )
+
+ if target_lang is not None and target_lang not in lang_code_to_id:
+ raise excs.Error(
+ f'Target language code {target_lang!r} is not supported by the model {model_id!r}. '
+ f'Supported languages are: {list(lang_code_to_id.keys())}'
+ )
+
+ with torch.no_grad():
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
+
+ # Set forced_bos_token_id for target language if supported
+ generate_kwargs = {'max_length': 512, 'num_beams': 4, 'early_stopping': True}
+
+ if target_lang is not None:
+ generate_kwargs['forced_bos_token_id'] = lang_code_to_id[target_lang]
+
+ outputs = model.generate(**inputs.to(device), **generate_kwargs)
+
+ # Decode all outputs at once
+ translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ return translations
+
+
+ @pxt.udf
+ def text_to_image(
+ prompt: str,
+ *,
+ model_id: str,
+ height: int = 512,
+ width: int = 512,
+ seed: int | None = None,
+ model_kwargs: dict[str, Any] | None = None,
+ ) -> PIL.Image.Image:
+ """
+ Generates images from text prompts using a pretrained text-to-image model. `model_id` should be a reference to a
+ pretrained [text-to-image model](https://huggingface.co/models?pipeline_tag=text-to-image) such as
+ Stable Diffusion or FLUX.
+
+ __Requirements:__
+
+ - `pip install torch transformers diffusers accelerate`
+
+ Args:
+ prompt: The text prompt describing the desired image.
+ model_id: The pretrained text-to-image model to use.
+ height: Height of the generated image in pixels.
+ width: Width of the generated image in pixels.
+ seed: Optional random seed for reproducibility.
+ model_kwargs: Additional keyword arguments to pass to the model, such as `num_inference_steps`,
+ `guidance_scale`, or `negative_prompt`.
+
+ Returns:
+ The generated Image.
+
+ Examples:
+ Add a computed column that generates images from text prompts:
+
+ >>> tbl.add_computed_column(generated_image=text_to_image(
+ ... tbl.prompt,
+ ... model_id='stable-diffusion-v1.5/stable-diffusion-v1-5',
+ ... height=512,
+ ... width=512,
+ ... model_kwargs={'num_inference_steps': 25},
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ env.Env.get().require_package('diffusers')
+ env.Env.get().require_package('accelerate')
+ device = resolve_torch_device('auto', allow_mps=False)
+ import torch
+ from diffusers import AutoPipelineForText2Image
+
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ # Parameter validation - following best practices pattern
+ if height <= 0 or width <= 0:
+ raise excs.Error(f'Height ({height}) and width ({width}) must be positive integers')
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise excs.Error(f'Height ({height}) and width ({width}) must be divisible by 8 for most diffusion models')
+
+ pipeline = _lookup_model(
+ model_id,
+ lambda x: AutoPipelineForText2Image.from_pretrained(
+ x,
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
+ device_map='auto' if device == 'cuda' else None,
+ safety_checker=None, # Disable safety checker for performance
+ requires_safety_checker=False,
+ ),
+ device=device,
+ )
+
+ try:
+ if device == 'cuda' and hasattr(pipeline, 'enable_model_cpu_offload'):
+ pipeline.enable_model_cpu_offload()
+ if hasattr(pipeline, 'enable_memory_efficient_attention'):
+ pipeline.enable_memory_efficient_attention()
+ except Exception:
+ pass # Ignore optimization failures
+
+ generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
+
+ with torch.no_grad():
+ result = pipeline(prompt, height=height, width=width, generator=generator, **model_kwargs)
+ return result.images[0]
+
+
+ @pxt.udf
+ def text_to_speech(text: str, *, model_id: str, speaker_id: int | None = None, vocoder: str | None = None) -> pxt.Audio:
+ """
+ Converts text to speech using a pretrained TTS model. `model_id` should be a reference to a
+ pretrained [text-to-speech model](https://huggingface.co/models?pipeline_tag=text-to-speech).
+
+ __Requirements:__
+
+ - `pip install torch transformers datasets soundfile`
+
+ Args:
+ text: The text to convert to speech.
+ model_id: The pretrained TTS model to use.
+ speaker_id: Speaker ID for multi-speaker models.
+ vocoder: Optional vocoder model for higher quality audio.
+
+ Returns:
+ The generated audio file.
+
+ Examples:
+ Add a computed column that converts text to speech:
+
+ >>> tbl.add_computed_column(audio=text_to_speech(
+ ... tbl.text_content,
+ ... model_id='microsoft/speecht5_tts',
+ ... speaker_id=0
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ env.Env.get().require_package('datasets')
+ env.Env.get().require_package('soundfile')
+ device = resolve_torch_device('auto')
+ import datasets # type: ignore[import-untyped]
+ import soundfile as sf # type: ignore[import-untyped]
+ import torch
+ from transformers import (
+ AutoModelForTextToWaveform,
+ AutoProcessor,
+ BarkModel,
+ SpeechT5ForTextToSpeech,
+ SpeechT5HifiGan,
+ SpeechT5Processor,
+ )
+
+ # Model loading with error handling - following best practices pattern
+ if 'speecht5' in model_id.lower():
+ model = _lookup_model(model_id, SpeechT5ForTextToSpeech.from_pretrained, device=device)
+ processor = _lookup_processor(model_id, SpeechT5Processor.from_pretrained)
+ vocoder_model_id = vocoder or 'microsoft/speecht5_hifigan'
+ vocoder_model = _lookup_model(vocoder_model_id, SpeechT5HifiGan.from_pretrained, device=device)
+
+ elif 'bark' in model_id.lower():
+ model = _lookup_model(model_id, BarkModel.from_pretrained, device=device)
+ processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+ vocoder_model = None
+
+ else:
+ model = _lookup_model(model_id, AutoModelForTextToWaveform.from_pretrained, device=device)
+ processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+ vocoder_model = None
+
+ # Load speaker embeddings once for SpeechT5 (following speech2text pattern)
+ speaker_embeddings = None
+ if 'speecht5' in model_id.lower():
+ ds: datasets.Dataset
+ if len(_speecht5_embeddings_dataset) == 0:
+ ds = datasets.load_dataset(
+ 'Matthijs/cmu-arctic-xvectors', split='validation', revision='refs/convert/parquet'
+ )
+ _speecht5_embeddings_dataset.append(ds)
+ else:
+ assert len(_speecht5_embeddings_dataset) == 1
+ ds = _speecht5_embeddings_dataset[0]
+ speaker_embeddings = torch.tensor(ds[speaker_id or 7306]['xvector']).unsqueeze(0).to(device)
+
+ with torch.no_grad():
+ # Generate speech based on model type
+ if 'speecht5' in model_id.lower():
+ inputs = processor(text=text, return_tensors='pt').to(device)
+ speech = model.generate_speech(inputs['input_ids'], speaker_embeddings, vocoder=vocoder_model)
+ audio_np = speech.cpu().numpy()
+ sample_rate = 16000
+
+ elif 'bark' in model_id.lower():
+ inputs = processor(text, return_tensors='pt').to(device)
+ audio_array = model.generate(**inputs)
+ audio_np = audio_array.cpu().numpy().squeeze()
+ sample_rate = getattr(model.generation_config, 'sample_rate', 24000)
+
+ else:
+ # Generic approach for other TTS models
+ inputs = processor(text, return_tensors='pt').to(device)
+ audio_output = model(**inputs)
+ audio_np = audio_output.waveform.cpu().numpy().squeeze()
+ sample_rate = getattr(model.config, 'sample_rate', 22050)
+
+ # Normalize audio - following consistent pattern
+ if audio_np.dtype != np.float32:
+ audio_np = audio_np.astype(np.float32)
+
+ if np.max(np.abs(audio_np)) > 0:
+ audio_np = audio_np / np.max(np.abs(audio_np)) * 0.9
+
+ # Create output file
+ output_filename = str(TempStore.create_path(extension='.wav'))
+ sf.write(output_filename, audio_np, sample_rate, format='WAV', subtype='PCM_16')
+ return output_filename
+
+
+ @pxt.udf
+ def image_to_image(
+ image: PIL.Image.Image,
+ prompt: str,
+ *,
+ model_id: str,
+ seed: int | None = None,
+ model_kwargs: dict[str, Any] | None = None,
+ ) -> PIL.Image.Image:
+ """
+ Transforms input images based on text prompts using a pretrained image-to-image model.
+ `model_id` should be a reference to a pretrained
+ [image-to-image model](https://huggingface.co/models?pipeline_tag=image-to-image).
+
+ __Requirements:__
+
+ - `pip install torch transformers diffusers accelerate`
+
+ Args:
+ image: The input image to transform.
+ prompt: The text prompt describing the desired transformation.
+ model_id: The pretrained image-to-image model to use.
+ seed: Random seed for reproducibility.
+ model_kwargs: Additional keyword arguments to pass to the model, such as `strength`,
+ `guidance_scale`, or `num_inference_steps`.
+
+ Returns:
+ The transformed image.
+
+ Examples:
+ Add a computed column that transforms images based on prompts:
+
+ >>> tbl.add_computed_column(transformed=image_to_image(
+ ... tbl.source_image,
+ ... tbl.transformation_prompt,
+ ... model_id='runwayml/stable-diffusion-v1-5'
+ ... ))
+ """
+ env.Env.get().require_package('transformers')
+ env.Env.get().require_package('diffusers')
+ env.Env.get().require_package('accelerate')
+ device = resolve_torch_device('auto')
+ import torch
+ from diffusers import StableDiffusionImg2ImgPipeline
+
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ pipe = _lookup_model(
+ model_id,
+ lambda x: StableDiffusionImg2ImgPipeline.from_pretrained(
+ x,
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
+ safety_checker=None,
+ requires_safety_checker=False,
+ ),
+ device=device,
+ )
+
+ try:
+ if device == 'cuda' and hasattr(pipe, 'enable_model_cpu_offload'):
+ pipe.enable_model_cpu_offload()
+ if hasattr(pipe, 'enable_memory_efficient_attention'):
+ pipe.enable_memory_efficient_attention()
+ except Exception:
+ pass # Ignore optimization failures
+
+ generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
+
+ processed_image = image.convert('RGB')
+
+ with torch.no_grad():
+ result = pipe(prompt=prompt, image=processed_image, generator=generator, **model_kwargs)
+ return result.images[0]
+
+
1221
+@pxt.udf
+def automatic_speech_recognition(
+    audio: pxt.Audio,
+    *,
     model_id: str,
-    create: Callable[..., T],
-    device: Optional[str] = None,
-    pass_device_to_create: bool = False
+    language: str | None = None,
+    chunk_length_s: int | None = None,
+    return_timestamps: bool = False,
+) -> str:
+    """
+    Transcribes speech to text using a pretrained ASR model. `model_id` should be a reference to a
+    pretrained [automatic-speech-recognition model](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition).
+
+    This is a **generic function** that works with many ASR model families. For production use with
+    specific models, consider specialized functions like `whisper.transcribe()` or
+    `speech2text_for_conditional_generation()`.
+
+    __Requirements:__
+
+    - `pip install torch transformers torchaudio`
+
+    __Recommended Models:__
+
+    - **OpenAI Whisper**: `openai/whisper-tiny.en`, `openai/whisper-small`, `openai/whisper-base`
+    - **Facebook Wav2Vec2**: `facebook/wav2vec2-base-960h`, `facebook/wav2vec2-large-960h-lv60-self`
+    - **Microsoft SpeechT5**: `microsoft/speecht5_asr`
+    - **Meta MMS (Multilingual)**: `facebook/mms-1b-all`
+
+    Args:
+        audio: The audio file(s) to transcribe.
+        model_id: The pretrained ASR model to use.
+        language: Language code for multilingual models (e.g., 'en', 'es', 'fr').
+        chunk_length_s: Maximum length of audio chunks in seconds for long audio processing.
+        return_timestamps: Whether to return word-level timestamps (model dependent).
+
+    Returns:
+        The transcribed text.
+
+    Examples:
+        Add a computed column that transcribes audio files:
+
+        >>> tbl.add_computed_column(transcription=automatic_speech_recognition(
+        ...     tbl.audio_file,
+        ...     model_id='openai/whisper-tiny.en'  # Recommended
+        ... ))
+
+        Transcribe with language specification:
+
+        >>> tbl.add_computed_column(transcription=automatic_speech_recognition(
+        ...     tbl.audio_file,
+        ...     model_id='facebook/mms-1b-all',
+        ...     language='en'
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('torchaudio')
+    device = resolve_torch_device('auto', allow_mps=False)  # Following speech2text pattern
+    import torch
+    import torchaudio
+
+    # Try to load model and processor using direct model loading - following speech2text pattern
+    # Handle different ASR model types
+    if 'whisper' in model_id.lower():
+        from transformers import WhisperForConditionalGeneration, WhisperProcessor
+
+        model = _lookup_model(model_id, WhisperForConditionalGeneration.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, WhisperProcessor.from_pretrained)
+
+        # Language validation for Whisper - following speech2text pattern
+        if language is not None and hasattr(processor.tokenizer, 'get_decoder_prompt_ids'):
+            try:
+                # Test if language is supported
+                _ = processor.tokenizer.get_decoder_prompt_ids(language=language)
+            except Exception:
+                raise excs.Error(
+                    f"Language code '{language}' is not supported by Whisper model '{model_id}'. "
+                    f"Try common codes like 'en', 'es', 'fr', 'de', 'it', 'pt', 'ru', 'ja', 'ko', 'zh'."
+                ) from None
+
+    elif 'wav2vec2' in model_id.lower():
+        from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+        model = _lookup_model(model_id, Wav2Vec2ForCTC.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, Wav2Vec2Processor.from_pretrained)
+
+    elif 'speech_to_text' in model_id.lower() or 's2t' in model_id.lower():
+        # Use the existing speech2text function for these models
+        from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+
+        model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
+
+    else:
+        # Generic fallback using Auto classes
+        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+
+        try:
+            model = _lookup_model(model_id, AutoModelForSpeechSeq2Seq.from_pretrained, device=device)
+            processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+        except Exception:
+            # Fallback to CTC models
+            from transformers import AutoModelForCTC
+
+            model = _lookup_model(model_id, AutoModelForCTC.from_pretrained, device=device)
+            processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+
+    # Get model's expected sampling rate - following speech2text pattern
+    model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
+
+    # Load and preprocess audio - following speech2text pattern
+    waveform, sampling_rate = torchaudio.load(audio)
+
+    # Resample if necessary
+    if sampling_rate != model_sampling_rate:
+        waveform = torchaudio.transforms.Resample(sampling_rate, model_sampling_rate)(waveform)
+
+    # Convert to mono if stereo
+    if waveform.dim() == 2:
+        waveform = torch.mean(waveform, dim=0)
+    assert waveform.dim() == 1
+
+    with torch.no_grad():
+        # Process audio with the model
+        inputs = processor(waveform, sampling_rate=model_sampling_rate, return_tensors='pt')
+
+        # Handle different model types for generation
+        if 'whisper' in model_id.lower():
+            # Whisper-specific generation
+            generate_kwargs = {}
+            if language is not None:
+                generate_kwargs['language'] = language
+            if return_timestamps:
+                generate_kwargs['return_timestamps'] = 'word'
+
+            generated_ids = model.generate(**inputs.to(device), **generate_kwargs)
+            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        elif hasattr(model, 'generate'):
+            # Seq2Seq models (Speech2Text, etc.)
+            generated_ids = model.generate(**inputs.to(device))
+            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        else:
+            # CTC models (Wav2Vec2, etc.)
+            logits = model(**inputs.to(device)).logits
+            predicted_ids = torch.argmax(logits, dim=-1)
+            transcription = processor.batch_decode(predicted_ids)[0]
+
+    return transcription.strip()
+
+
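The resample-and-downmix preprocessing above is the standard way to match a clip to a model's expected sampling rate. A standalone sketch of just that step, mirroring the calls used in the UDF (the helper name, input path, and 16 kHz target are illustrative):

    import torch
    import torchaudio

    def load_mono_waveform(path: str, target_rate: int = 16_000) -> torch.Tensor:
        # torchaudio.load returns a (channels, samples) tensor and the file's native sampling rate
        waveform, rate = torchaudio.load(path)
        # Resample only when the rates differ
        if rate != target_rate:
            waveform = torchaudio.transforms.Resample(rate, target_rate)(waveform)
        # Average the channels to get a mono (samples,) tensor
        if waveform.dim() == 2:
            waveform = torch.mean(waveform, dim=0)
        return waveform

    # waveform = load_mono_waveform('speech.wav')  # hypothetical input file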
+@pxt.udf
+def image_to_video(
+    image: PIL.Image.Image,
+    *,
+    model_id: str,
+    num_frames: int = 25,
+    fps: int = 6,
+    seed: int | None = None,
+    model_kwargs: dict[str, Any] | None = None,
+) -> pxt.Video:
+    """
+    Generates videos from input images using a pretrained image-to-video model.
+    `model_id` should be a reference to a pretrained
+    [image-to-video model](https://huggingface.co/models?pipeline_tag=image-to-video).
+
+    __Requirements:__
+
+    - `pip install torch transformers diffusers accelerate`
+
+    Args:
+        image: The input image to animate into a video.
+        model_id: The pretrained image-to-video model to use.
+        num_frames: Number of video frames to generate.
+        fps: Frames per second for the output video.
+        seed: Random seed for reproducibility.
+        model_kwargs: Additional keyword arguments to pass to the model, such as `num_inference_steps`,
+            `motion_bucket_id`, or `guidance_scale`.
+
+    Returns:
+        The generated video file.
+
+    Examples:
+        Add a computed column that creates videos from images:
+
+        >>> tbl.add_computed_column(video=image_to_video(
+        ...     tbl.input_image,
+        ...     model_id='stabilityai/stable-video-diffusion-img2vid-xt',
+        ...     num_frames=25,
+        ...     fps=7
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('diffusers')
+    env.Env.get().require_package('accelerate')
+    device = resolve_torch_device('auto', allow_mps=False)
+    import numpy as np
+    import torch
+    from diffusers import StableVideoDiffusionPipeline
+
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    # Parameter validation
+    if num_frames < 1:
+        raise excs.Error(f'num_frames must be at least 1, got {num_frames}')
+
+    if num_frames > 25:
+        raise excs.Error(f'num_frames cannot exceed 25 for most video diffusion models, got {num_frames}')
+
+    if fps < 1:
+        raise excs.Error(f'fps must be at least 1, got {fps}')
+
+    if fps > 60:
+        raise excs.Error(f'fps should not exceed 60 for reasonable video generation, got {fps}')
+
+    pipe = _lookup_model(
+        model_id,
+        lambda x: StableVideoDiffusionPipeline.from_pretrained(
+            x,
+            torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
+            variant='fp16' if device == 'cuda' else None,
+        ),
+        device=device,
+    )
+
+    try:
+        if device == 'cuda' and hasattr(pipe, 'enable_model_cpu_offload'):
+            pipe.enable_model_cpu_offload()
+        if hasattr(pipe, 'enable_memory_efficient_attention'):
+            pipe.enable_memory_efficient_attention()
+    except Exception:
+        pass  # Ignore optimization failures
+
+    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
+
+    # Ensure image is in RGB mode and proper size
+    processed_image = image.convert('RGB')
+    target_width, target_height = 512, 320
+    processed_image = processed_image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS)
+
+    # Generate video frames
+    with torch.no_grad():
+        result = pipe(image=processed_image, num_frames=num_frames, generator=generator, **model_kwargs)
+        frames = result.frames[0]
+
+    # Create output video file
+    output_path = str(TempStore.create_path(extension='.mp4'))
+
+    with av.open(output_path, mode='w') as container:
+        stream = container.add_stream('h264', rate=fps)
+        stream.width = target_width
+        stream.height = target_height
+        stream.pix_fmt = 'yuv420p'
+
+        # Set codec options for better compatibility
+        stream.codec_context.options = {'crf': '23', 'preset': 'medium'}
+
+        for frame_pil in frames:
+            # Convert PIL to numpy array
+            frame_array = np.array(frame_pil)
+            # Create av VideoFrame
+            av_frame = av.VideoFrame.from_ndarray(frame_array, format='rgb24')
+            # Encode and mux
+            for packet in stream.encode(av_frame):
+                container.mux(packet)
+
+        # Flush encoder
+        for packet in stream.encode():
+            container.mux(packet)
+
+    return output_path
+
+
+def _lookup_model(
+    model_id: str, create: Callable[..., T], device: str | None = None, pass_device_to_create: bool = False
 ) -> T:
     from torch import nn

@@ -526,12 +1518,13 @@ def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
     return _processor_cache[key]


-_model_cache: dict[tuple[str, Callable, Optional[str]], Any] = {}
+_model_cache: dict[tuple[str, Callable, str | None], Any] = {}
+_speecht5_embeddings_dataset: list[Any] = []  # contains only the speecht5 embeddings loaded by text_to_speech()
 _processor_cache: dict[tuple[str, Callable], Any] = {}


 __all__ = local_public_names(__name__)


-def __dir__():
+def __dir__() -> list[str]:
     return __all__
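The module-level caches above memoize loaded models and processors for the lifetime of the process, keyed by model id, factory callable, and (for models) device. A simplified illustration of that idea, not the module's actual `_lookup_model` implementation (whose body is not shown in this hunk); the function and key layout here are assumptions based only on the cache's type annotation:

    from typing import Any, Callable

    _model_cache: dict[tuple[str, Callable, str | None], Any] = {}

    def lookup_model(model_id: str, create: Callable[[str], Any], device: str | None = None) -> Any:
        # Reuse a previously constructed model when the same id/factory/device is requested again
        key = (model_id, create, device)
        if key not in _model_cache:
            model = create(model_id)
            if device is not None and hasattr(model, 'to'):
                model = model.to(device)  # move to the requested device once, at load time
            _model_cache[key] = model
        return _model_cache[key]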