pixeltable 0.3.14__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. pixeltable/__init__.py +42 -8
  2. pixeltable/{dataframe.py → _query.py} +470 -206
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +5 -4
  5. pixeltable/catalog/catalog.py +1785 -432
  6. pixeltable/catalog/column.py +190 -113
  7. pixeltable/catalog/dir.py +2 -4
  8. pixeltable/catalog/globals.py +19 -46
  9. pixeltable/catalog/insertable_table.py +191 -98
  10. pixeltable/catalog/path.py +63 -23
  11. pixeltable/catalog/schema_object.py +11 -15
  12. pixeltable/catalog/table.py +843 -436
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +978 -657
  15. pixeltable/catalog/table_version_handle.py +72 -16
  16. pixeltable/catalog/table_version_path.py +112 -43
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +134 -90
  20. pixeltable/config.py +134 -22
  21. pixeltable/env.py +471 -157
  22. pixeltable/exceptions.py +6 -0
  23. pixeltable/exec/__init__.py +4 -1
  24. pixeltable/exec/aggregation_node.py +7 -8
  25. pixeltable/exec/cache_prefetch_node.py +83 -110
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +4 -3
  29. pixeltable/exec/data_row_batch.py +8 -65
  30. pixeltable/exec/exec_context.py +16 -4
  31. pixeltable/exec/exec_node.py +13 -36
  32. pixeltable/exec/expr_eval/evaluators.py +11 -7
  33. pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
  34. pixeltable/exec/expr_eval/globals.py +8 -5
  35. pixeltable/exec/expr_eval/row_buffer.py +1 -2
  36. pixeltable/exec/expr_eval/schedulers.py +106 -56
  37. pixeltable/exec/globals.py +35 -0
  38. pixeltable/exec/in_memory_data_node.py +19 -19
  39. pixeltable/exec/object_store_save_node.py +293 -0
  40. pixeltable/exec/row_update_node.py +16 -9
  41. pixeltable/exec/sql_node.py +351 -84
  42. pixeltable/exprs/__init__.py +1 -1
  43. pixeltable/exprs/arithmetic_expr.py +27 -22
  44. pixeltable/exprs/array_slice.py +3 -3
  45. pixeltable/exprs/column_property_ref.py +36 -23
  46. pixeltable/exprs/column_ref.py +213 -89
  47. pixeltable/exprs/comparison.py +5 -5
  48. pixeltable/exprs/compound_predicate.py +5 -4
  49. pixeltable/exprs/data_row.py +164 -54
  50. pixeltable/exprs/expr.py +70 -44
  51. pixeltable/exprs/expr_dict.py +3 -3
  52. pixeltable/exprs/expr_set.py +17 -10
  53. pixeltable/exprs/function_call.py +100 -40
  54. pixeltable/exprs/globals.py +2 -2
  55. pixeltable/exprs/in_predicate.py +4 -4
  56. pixeltable/exprs/inline_expr.py +18 -32
  57. pixeltable/exprs/is_null.py +7 -3
  58. pixeltable/exprs/json_mapper.py +8 -8
  59. pixeltable/exprs/json_path.py +56 -22
  60. pixeltable/exprs/literal.py +27 -5
  61. pixeltable/exprs/method_ref.py +2 -2
  62. pixeltable/exprs/object_ref.py +2 -2
  63. pixeltable/exprs/row_builder.py +167 -67
  64. pixeltable/exprs/rowid_ref.py +25 -10
  65. pixeltable/exprs/similarity_expr.py +58 -40
  66. pixeltable/exprs/sql_element_cache.py +4 -4
  67. pixeltable/exprs/string_op.py +5 -5
  68. pixeltable/exprs/type_cast.py +3 -5
  69. pixeltable/func/__init__.py +1 -0
  70. pixeltable/func/aggregate_function.py +8 -8
  71. pixeltable/func/callable_function.py +9 -9
  72. pixeltable/func/expr_template_function.py +17 -11
  73. pixeltable/func/function.py +18 -20
  74. pixeltable/func/function_registry.py +6 -7
  75. pixeltable/func/globals.py +2 -3
  76. pixeltable/func/mcp.py +74 -0
  77. pixeltable/func/query_template_function.py +29 -27
  78. pixeltable/func/signature.py +46 -19
  79. pixeltable/func/tools.py +31 -13
  80. pixeltable/func/udf.py +18 -20
  81. pixeltable/functions/__init__.py +16 -0
  82. pixeltable/functions/anthropic.py +123 -77
  83. pixeltable/functions/audio.py +147 -10
  84. pixeltable/functions/bedrock.py +13 -6
  85. pixeltable/functions/date.py +7 -4
  86. pixeltable/functions/deepseek.py +35 -43
  87. pixeltable/functions/document.py +81 -0
  88. pixeltable/functions/fal.py +76 -0
  89. pixeltable/functions/fireworks.py +11 -20
  90. pixeltable/functions/gemini.py +195 -39
  91. pixeltable/functions/globals.py +142 -14
  92. pixeltable/functions/groq.py +108 -0
  93. pixeltable/functions/huggingface.py +1056 -24
  94. pixeltable/functions/image.py +115 -57
  95. pixeltable/functions/json.py +1 -1
  96. pixeltable/functions/llama_cpp.py +28 -13
  97. pixeltable/functions/math.py +67 -5
  98. pixeltable/functions/mistralai.py +18 -55
  99. pixeltable/functions/net.py +70 -0
  100. pixeltable/functions/ollama.py +20 -13
  101. pixeltable/functions/openai.py +240 -226
  102. pixeltable/functions/openrouter.py +143 -0
  103. pixeltable/functions/replicate.py +4 -4
  104. pixeltable/functions/reve.py +250 -0
  105. pixeltable/functions/string.py +239 -69
  106. pixeltable/functions/timestamp.py +16 -16
  107. pixeltable/functions/together.py +24 -84
  108. pixeltable/functions/twelvelabs.py +188 -0
  109. pixeltable/functions/util.py +6 -1
  110. pixeltable/functions/uuid.py +30 -0
  111. pixeltable/functions/video.py +1515 -107
  112. pixeltable/functions/vision.py +8 -8
  113. pixeltable/functions/voyageai.py +289 -0
  114. pixeltable/functions/whisper.py +16 -8
  115. pixeltable/functions/whisperx.py +179 -0
  116. pixeltable/{ext/functions → functions}/yolox.py +2 -4
  117. pixeltable/globals.py +362 -115
  118. pixeltable/index/base.py +17 -21
  119. pixeltable/index/btree.py +28 -22
  120. pixeltable/index/embedding_index.py +100 -118
  121. pixeltable/io/__init__.py +4 -2
  122. pixeltable/io/datarows.py +8 -7
  123. pixeltable/io/external_store.py +56 -105
  124. pixeltable/io/fiftyone.py +13 -13
  125. pixeltable/io/globals.py +31 -30
  126. pixeltable/io/hf_datasets.py +61 -16
  127. pixeltable/io/label_studio.py +74 -70
  128. pixeltable/io/lancedb.py +3 -0
  129. pixeltable/io/pandas.py +21 -12
  130. pixeltable/io/parquet.py +25 -105
  131. pixeltable/io/table_data_conduit.py +250 -123
  132. pixeltable/io/utils.py +4 -4
  133. pixeltable/iterators/__init__.py +2 -1
  134. pixeltable/iterators/audio.py +26 -25
  135. pixeltable/iterators/base.py +9 -3
  136. pixeltable/iterators/document.py +112 -78
  137. pixeltable/iterators/image.py +12 -15
  138. pixeltable/iterators/string.py +11 -4
  139. pixeltable/iterators/video.py +523 -120
  140. pixeltable/metadata/__init__.py +14 -3
  141. pixeltable/metadata/converters/convert_13.py +2 -2
  142. pixeltable/metadata/converters/convert_18.py +2 -2
  143. pixeltable/metadata/converters/convert_19.py +2 -2
  144. pixeltable/metadata/converters/convert_20.py +2 -2
  145. pixeltable/metadata/converters/convert_21.py +2 -2
  146. pixeltable/metadata/converters/convert_22.py +2 -2
  147. pixeltable/metadata/converters/convert_24.py +2 -2
  148. pixeltable/metadata/converters/convert_25.py +2 -2
  149. pixeltable/metadata/converters/convert_26.py +2 -2
  150. pixeltable/metadata/converters/convert_29.py +4 -4
  151. pixeltable/metadata/converters/convert_30.py +34 -21
  152. pixeltable/metadata/converters/convert_34.py +2 -2
  153. pixeltable/metadata/converters/convert_35.py +9 -0
  154. pixeltable/metadata/converters/convert_36.py +38 -0
  155. pixeltable/metadata/converters/convert_37.py +15 -0
  156. pixeltable/metadata/converters/convert_38.py +39 -0
  157. pixeltable/metadata/converters/convert_39.py +124 -0
  158. pixeltable/metadata/converters/convert_40.py +73 -0
  159. pixeltable/metadata/converters/convert_41.py +12 -0
  160. pixeltable/metadata/converters/convert_42.py +9 -0
  161. pixeltable/metadata/converters/convert_43.py +44 -0
  162. pixeltable/metadata/converters/util.py +20 -31
  163. pixeltable/metadata/notes.py +9 -0
  164. pixeltable/metadata/schema.py +140 -53
  165. pixeltable/metadata/utils.py +74 -0
  166. pixeltable/mypy/__init__.py +3 -0
  167. pixeltable/mypy/mypy_plugin.py +123 -0
  168. pixeltable/plan.py +382 -115
  169. pixeltable/share/__init__.py +1 -1
  170. pixeltable/share/packager.py +547 -83
  171. pixeltable/share/protocol/__init__.py +33 -0
  172. pixeltable/share/protocol/common.py +165 -0
  173. pixeltable/share/protocol/operation_types.py +33 -0
  174. pixeltable/share/protocol/replica.py +119 -0
  175. pixeltable/share/publish.py +257 -59
  176. pixeltable/store.py +311 -194
  177. pixeltable/type_system.py +373 -211
  178. pixeltable/utils/__init__.py +2 -3
  179. pixeltable/utils/arrow.py +131 -17
  180. pixeltable/utils/av.py +298 -0
  181. pixeltable/utils/azure_store.py +346 -0
  182. pixeltable/utils/coco.py +6 -6
  183. pixeltable/utils/code.py +3 -3
  184. pixeltable/utils/console_output.py +4 -1
  185. pixeltable/utils/coroutine.py +6 -23
  186. pixeltable/utils/dbms.py +32 -6
  187. pixeltable/utils/description_helper.py +4 -5
  188. pixeltable/utils/documents.py +7 -18
  189. pixeltable/utils/exception_handler.py +7 -30
  190. pixeltable/utils/filecache.py +6 -6
  191. pixeltable/utils/formatter.py +86 -48
  192. pixeltable/utils/gcs_store.py +295 -0
  193. pixeltable/utils/http.py +133 -0
  194. pixeltable/utils/http_server.py +2 -3
  195. pixeltable/utils/iceberg.py +1 -2
  196. pixeltable/utils/image.py +17 -0
  197. pixeltable/utils/lancedb.py +90 -0
  198. pixeltable/utils/local_store.py +322 -0
  199. pixeltable/utils/misc.py +5 -0
  200. pixeltable/utils/object_stores.py +573 -0
  201. pixeltable/utils/pydantic.py +60 -0
  202. pixeltable/utils/pytorch.py +5 -6
  203. pixeltable/utils/s3_store.py +527 -0
  204. pixeltable/utils/sql.py +26 -0
  205. pixeltable/utils/system.py +30 -0
  206. pixeltable-0.5.7.dist-info/METADATA +579 -0
  207. pixeltable-0.5.7.dist-info/RECORD +227 -0
  208. {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  209. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  210. pixeltable/__version__.py +0 -3
  211. pixeltable/catalog/named_function.py +0 -40
  212. pixeltable/ext/__init__.py +0 -17
  213. pixeltable/ext/functions/__init__.py +0 -11
  214. pixeltable/ext/functions/whisperx.py +0 -77
  215. pixeltable/utils/media_store.py +0 -77
  216. pixeltable/utils/s3.py +0 -17
  217. pixeltable-0.3.14.dist-info/METADATA +0 -434
  218. pixeltable-0.3.14.dist-info/RECORD +0 -186
  219. pixeltable-0.3.14.dist-info/entry_points.txt +0 -3
  220. {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/functions/globals.py
@@ -1,6 +1,6 @@
1
1
  import builtins
2
2
  import typing
3
- from typing import Any, Callable, Optional, Union
3
+ from typing import Any, Callable
4
4
 
5
5
  import sqlalchemy as sql
6
6
 
@@ -11,7 +11,7 @@ from typing import _GenericAlias # type: ignore[attr-defined] # isort: skip
11
11
 
12
12
 
13
13
  # TODO: remove and replace calls with astype()
14
- def cast(expr: exprs.Expr, target_type: Union[ts.ColumnType, type, _GenericAlias]) -> exprs.Expr:
14
+ def cast(expr: exprs.Expr, target_type: ts.ColumnType | type | _GenericAlias) -> exprs.Expr:
15
15
  expr.col_type = ts.ColumnType.normalize_type(target_type)
16
16
  return expr
17
17
 
@@ -19,9 +19,30 @@ def cast(expr: exprs.Expr, target_type: Union[ts.ColumnType, type, _GenericAlias
19
19
  T = typing.TypeVar('T')
20
20
 
21
21
 
22
- @func.uda(allows_window=True, type_substitutions=({T: Optional[int]}, {T: Optional[float]})) # type: ignore[misc]
22
+ @func.uda(allows_window=True, type_substitutions=({T: int | None}, {T: float | None})) # type: ignore[misc]
23
23
  class sum(func.Aggregator, typing.Generic[T]):
24
- """Sums the selected integers or floats."""
24
+ """
25
+ Aggregate function that computes the sum of non-null values of a numeric column or grouping.
26
+
27
+ Args:
28
+ val: The numeric value to add to the sum.
29
+
30
+ Returns:
31
+ The sum of the non-null values, or `None` if there are no non-null values.
32
+
33
+ Examples:
34
+ Sum the values in the `value` column of the table `tbl`:
35
+
36
+ >>> tbl.select(pxt.functions.sum(tbl.value)).collect()
37
+
38
+ Group by the `category` column and compute the sum of the `value` column for each category,
39
+ assigning the name `'category_total'` to the new column:
40
+
41
+ >>> tbl.group_by(tbl.category).select(
42
+ ... tbl.category,
43
+ ... category_total=pxt.functions.sum(tbl.value)
44
+ ... ).collect()
45
+ """
25
46
 
26
47
  def __init__(self) -> None:
27
48
  self.sum: T = None
@@ -39,7 +60,7 @@ class sum(func.Aggregator, typing.Generic[T]):
39
60
 
40
61
 
41
62
  @sum.to_sql
42
- def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
63
+ def _(val: sql.ColumnElement) -> sql.ColumnElement | None:
43
64
  # This can produce a Decimal. We are deliberately avoiding an explicit cast to a Bigint here, because that can
44
65
  # cause overflows in Postgres. We're instead doing the conversion to the target type in SqlNode.__iter__().
45
66
  return sql.sql.func.sum(val)
@@ -49,9 +70,32 @@ def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
49
70
  allows_window=True,
50
71
  # Allow counting non-null values of any type
51
72
  # TODO: should we have an "Any" type that can be used here?
52
- type_substitutions=tuple({T: Optional[t]} for t in ts.ALL_PIXELTABLE_TYPES), # type: ignore[misc]
73
+ type_substitutions=tuple({T: t | None} for t in ts.ALL_PIXELTABLE_TYPES), # type: ignore[misc]
53
74
  )
54
75
  class count(func.Aggregator, typing.Generic[T]):
76
+ """
77
+ Aggregate function that counts the number of non-null values in a column or grouping.
78
+
79
+ Args:
80
+ val: The value to count.
81
+
82
+ Returns:
83
+ The count of non-null values.
84
+
85
+ Examples:
86
+ Count the number of non-null values in the `value` column of the table `tbl`:
87
+
88
+ >>> tbl.select(pxt.functions.count(tbl.value)).collect()
89
+
90
+ Group by the `category` column and compute the count of non-null values in the `value` column
91
+ for each category, assigning the name `'category_count'` to the new column:
92
+
93
+ >>> tbl.group_by(tbl.category).select(
94
+ ... tbl.category,
95
+ ... category_count=pxt.functions.count(tbl.value)
96
+ ... ).collect()
97
+ """
98
+
55
99
  def __init__(self) -> None:
56
100
  self.count = 0
57
101
 
@@ -64,15 +108,38 @@ class count(func.Aggregator, typing.Generic[T]):
64
108
 
65
109
 
66
110
  @count.to_sql
67
- def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
111
+ def _(val: sql.ColumnElement) -> sql.ColumnElement | None:
68
112
  return sql.sql.func.count(val)
69
113
 
70
114
 
71
115
  @func.uda(
72
116
  allows_window=True,
73
- type_substitutions=tuple({T: Optional[t]} for t in (str, int, float, bool, ts.Timestamp)), # type: ignore[misc]
117
+ type_substitutions=tuple({T: t | None} for t in (str, int, float, bool, ts.Timestamp)), # type: ignore[misc]
74
118
  )
75
119
  class min(func.Aggregator, typing.Generic[T]):
120
+ """
121
+ Aggregate function that computes the minimum value in a column or grouping.
122
+
123
+ Args:
124
+ val: The value to compare.
125
+
126
+ Returns:
127
+ The minimum value, or `None` if there are no non-null values.
128
+
129
+ Examples:
130
+ Compute the minimum value in the `value` column of the table `tbl`:
131
+
132
+ >>> tbl.select(pxt.functions.min(tbl.value)).collect()
133
+
134
+ Group by the `category` column and compute the minimum value in the `value` column for each category,
135
+ assigning the name `'category_min'` to the new column:
136
+
137
+ >>> tbl.group_by(tbl.category).select(
138
+ ... tbl.category,
139
+ ... category_min=pxt.functions.min(tbl.value)
140
+ ... ).collect()
141
+ """
142
+
76
143
  def __init__(self) -> None:
77
144
  self.val: T = None
78
145
 
@@ -89,7 +156,7 @@ class min(func.Aggregator, typing.Generic[T]):
89
156
 
90
157
 
91
158
  @min.to_sql
92
- def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
159
+ def _(val: sql.ColumnElement) -> sql.ColumnElement | None:
93
160
  if val.type.python_type is bool:
94
161
  # TODO: min/max aggregation of booleans is not supported in Postgres (but it is in Python).
95
162
  # Right now we simply force the computation to be done in Python; we might consider implementing an alternate
@@ -100,9 +167,32 @@ def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
100
167
 
101
168
  @func.uda(
102
169
  allows_window=True,
103
- type_substitutions=tuple({T: Optional[t]} for t in (str, int, float, bool, ts.Timestamp)), # type: ignore[misc]
170
+ type_substitutions=tuple({T: t | None} for t in (str, int, float, bool, ts.Timestamp)), # type: ignore[misc]
104
171
  )
105
172
  class max(func.Aggregator, typing.Generic[T]):
173
+ """
174
+ Aggregate function that computes the maximum value in a column or grouping.
175
+
176
+ Args:
177
+ val: The value to compare.
178
+
179
+ Returns:
180
+ The maximum value, or `None` if there are no non-null values.
181
+
182
+ Examples:
183
+ Compute the maximum value in the `value` column of the table `tbl`:
184
+
185
+ >>> tbl.select(pxt.functions.max(tbl.value)).collect()
186
+
187
+ Group by the `category` column and compute the maximum value in the `value` column for each category,
188
+ assigning the name `'category_max'` to the new column:
189
+
190
+ >>> tbl.group_by(tbl.category).select(
191
+ ... tbl.category,
192
+ ... category_max=pxt.functions.max(tbl.value)
193
+ ... ).collect()
194
+ """
195
+
106
196
  def __init__(self) -> None:
107
197
  self.val: T = None
108
198
 
@@ -119,15 +209,38 @@ class max(func.Aggregator, typing.Generic[T]):
119
209
 
120
210
 
121
211
  @max.to_sql
122
- def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
212
+ def _(val: sql.ColumnElement) -> sql.ColumnElement | None:
123
213
  if val.type.python_type is bool:
124
214
  # TODO: see comment in @min.to_sql.
125
215
  return None
126
216
  return sql.sql.func.max(val)
127
217
 
128
218
 
129
- @func.uda(type_substitutions=({T: Optional[int]}, {T: Optional[float]})) # type: ignore[misc]
219
+ @func.uda(type_substitutions=({T: int | None}, {T: float | None})) # type: ignore[misc]
130
220
  class mean(func.Aggregator, typing.Generic[T]):
221
+ """
222
+ Aggregate function that computes the mean (average) of non-null values of a numeric column or grouping.
223
+
224
+ Args:
225
+ val: The numeric value to include in the mean.
226
+
227
+ Returns:
228
+ The mean of the non-null values, or `None` if there are no non-null values.
229
+
230
+ Examples:
231
+ Compute the mean of the values in the `value` column of the table `tbl`:
232
+
233
+ >>> tbl.select(pxt.functions.mean(tbl.value)).collect()
234
+
235
+ Group by the `category` column and compute the mean of the `value` column for each category,
236
+ assigning the name `'category_mean'` to the new column:
237
+
238
+ >>> tbl.group_by(tbl.category).select(
239
+ ... tbl.category,
240
+ ... category_mean=pxt.functions.mean(tbl.value)
241
+ ... ).collect()
242
+ """
243
+
131
244
  def __init__(self) -> None:
132
245
  self.sum: T = None
133
246
  self.count = 0
@@ -141,18 +254,33 @@ class mean(func.Aggregator, typing.Generic[T]):
141
254
  self.sum += val # type: ignore[operator]
142
255
  self.count += 1
143
256
 
144
- def value(self) -> Optional[float]: # Always a float
257
+ def value(self) -> float | None: # Always a float
145
258
  if self.count == 0:
146
259
  return None
147
260
  return self.sum / self.count # type: ignore[operator]
148
261
 
149
262
 
150
263
  @mean.to_sql
151
- def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
264
+ def _(val: sql.ColumnElement) -> sql.ColumnElement | None:
152
265
  return sql.sql.func.avg(val)
153
266
 
154
267
 
155
268
  def map(expr: exprs.Expr, fn: Callable[[exprs.Expr], Any]) -> exprs.Expr:
269
+ """
270
+ Applies a mapping function to each element of a list.
271
+
272
+ Args:
273
+ expr: The list expression to map over; must be an expression of type `pxt.Json`.
274
+ fn: An operation on Pixeltable expressions that will be applied to each element of the JSON array.
275
+
276
+ Examples:
277
+ Given a table `tbl` with a column `data` of type `pxt.Json` containing lists of integers, add a computed
278
+ column that produces new lists with each integer doubled:
279
+
280
+ >>> tbl.add_computed_column(
281
+ ... doubled=pxt.functions.map(t.data, lambda x: x * 2)
282
+ ... )
283
+ """
156
284
  target_expr: exprs.Expr
157
285
  try:
158
286
  target_expr = exprs.Expr.from_object(fn(exprs.json_path.RELATIVE_PATH_ROOT))
pixeltable/functions/groq.py
@@ -0,0 +1,108 @@
1
+ """
2
+ Pixeltable UDFs
3
+ that wrap various endpoints from the Groq API. In order to use them, you must
4
+ first `pip install groq` and configure your Groq credentials, as described in
5
+ the [Working with Groq](https://docs.pixeltable.com/notebooks/integrations/working-with-groq) tutorial.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import pixeltable as pxt
11
+ from pixeltable import exprs
12
+ from pixeltable.env import Env, register_client
13
+ from pixeltable.utils.code import local_public_names
14
+
15
+ from .openai import _openai_response_to_pxt_tool_calls
16
+
17
+ if TYPE_CHECKING:
18
+ import groq
19
+
20
+
21
+ @register_client('groq')
22
+ def _(api_key: str) -> 'groq.AsyncGroq':
23
+ import groq
24
+
25
+ return groq.AsyncGroq(api_key=api_key)
26
+
27
+
28
+ def _groq_client() -> 'groq.AsyncGroq':
29
+ return Env.get().get_client('groq')
30
+
31
+
32
+ @pxt.udf(resource_pool='request-rate:groq')
33
+ async def chat_completions(
34
+ messages: list[dict[str, str]],
35
+ *,
36
+ model: str,
37
+ model_kwargs: dict[str, Any] | None = None,
38
+ tools: list[dict[str, Any]] | None = None,
39
+ tool_choice: dict[str, Any] | None = None,
40
+ ) -> dict:
41
+ """
42
+ Chat Completion API.
43
+
44
+ Equivalent to the Groq `chat/completions` API endpoint.
45
+ For additional details, see: <https://console.groq.com/docs/api-reference#chat-create>
46
+
47
+ Request throttling:
48
+ Applies the rate limit set in the config (section `groq`, key `rate_limit`). If no rate
49
+ limit is configured, uses a default of 600 RPM.
50
+
51
+ __Requirements:__
52
+
53
+ - `pip install groq`
54
+
55
+ Args:
56
+ messages: A list of messages comprising the conversation so far.
57
+ model: ID of the model to use. (See overview here: <https://console.groq.com/docs/models>)
58
+ model_kwargs: Additional keyword args for the Groq `chat/completions` API.
59
+ For details on the available parameters, see: <https://console.groq.com/docs/api-reference#chat-create>
60
+
61
+ Returns:
62
+ A dictionary containing the response and other metadata.
63
+
64
+ Examples:
65
+ Add a computed column that applies the model `llama-3.1-8b-instant`
66
+ to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
67
+
68
+ >>> messages = [{'role': 'user', 'content': tbl.prompt}]
69
+ ... tbl.add_computed_column(response=chat_completions(messages, model='llama-3.1-8b-instant'))
70
+ """
71
+ if model_kwargs is None:
72
+ model_kwargs = {}
73
+
74
+ Env.get().require_package('groq')
75
+
76
+ if tools is not None:
77
+ model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
78
+
79
+ if tool_choice is not None:
80
+ if tool_choice['auto']:
81
+ model_kwargs['tool_choice'] = 'auto'
82
+ elif tool_choice['required']:
83
+ model_kwargs['tool_choice'] = 'required'
84
+ else:
85
+ assert tool_choice['tool'] is not None
86
+ model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
87
+
88
+ if tool_choice is not None and not tool_choice['parallel_tool_calls']:
89
+ model_kwargs['parallel_tool_calls'] = False
90
+
91
+ result = await _groq_client().chat.completions.create(
92
+ messages=messages, # type: ignore[arg-type]
93
+ model=model,
94
+ **model_kwargs,
95
+ )
96
+ return result.model_dump()
97
+
98
+
99
+ def invoke_tools(tools: pxt.func.Tools, response: exprs.Expr) -> exprs.InlineDict:
100
+ """Converts an OpenAI response dict to Pixeltable tool invocation format and calls `tools._invoke()`."""
101
+ return tools._invoke(_openai_response_to_pxt_tool_calls(response))
102
+
103
+
104
+ __all__ = local_public_names(__name__)
105
+
106
+
107
+ def __dir__() -> list[str]:
108
+ return __all__