pixeltable-0.2.18-py3-none-any.whl → pixeltable-0.2.20-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (42):
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/insertable_table.py +9 -7
  4. pixeltable/catalog/table.py +18 -5
  5. pixeltable/catalog/table_version.py +1 -1
  6. pixeltable/catalog/view.py +1 -1
  7. pixeltable/dataframe.py +1 -1
  8. pixeltable/env.py +140 -40
  9. pixeltable/exceptions.py +12 -5
  10. pixeltable/exec/component_iteration_node.py +63 -42
  11. pixeltable/exprs/__init__.py +1 -2
  12. pixeltable/exprs/expr.py +5 -6
  13. pixeltable/exprs/function_call.py +8 -10
  14. pixeltable/exprs/inline_expr.py +200 -0
  15. pixeltable/exprs/json_path.py +3 -6
  16. pixeltable/ext/functions/whisperx.py +2 -0
  17. pixeltable/ext/functions/yolox.py +5 -3
  18. pixeltable/functions/huggingface.py +89 -12
  19. pixeltable/functions/image.py +3 -3
  20. pixeltable/functions/together.py +37 -16
  21. pixeltable/functions/vision.py +43 -21
  22. pixeltable/functions/whisper.py +3 -0
  23. pixeltable/globals.py +7 -1
  24. pixeltable/io/globals.py +1 -1
  25. pixeltable/io/hf_datasets.py +3 -3
  26. pixeltable/iterators/document.py +1 -1
  27. pixeltable/metadata/__init__.py +1 -1
  28. pixeltable/metadata/converters/convert_18.py +1 -1
  29. pixeltable/metadata/converters/convert_20.py +56 -0
  30. pixeltable/metadata/converters/util.py +29 -4
  31. pixeltable/metadata/notes.py +1 -0
  32. pixeltable/tool/create_test_db_dump.py +15 -4
  33. pixeltable/type_system.py +3 -1
  34. pixeltable/utils/filecache.py +126 -79
  35. pixeltable-0.2.20.dist-info/LICENSE +201 -0
  36. {pixeltable-0.2.18.dist-info → pixeltable-0.2.20.dist-info}/METADATA +16 -6
  37. {pixeltable-0.2.18.dist-info → pixeltable-0.2.20.dist-info}/RECORD +39 -39
  38. pixeltable/exprs/inline_array.py +0 -117
  39. pixeltable/exprs/inline_dict.py +0 -104
  40. pixeltable-0.2.18.dist-info/LICENSE +0 -18
  41. {pixeltable-0.2.18.dist-info → pixeltable-0.2.20.dist-info}/WHEEL +0 -0
  42. {pixeltable-0.2.18.dist-info → pixeltable-0.2.20.dist-info}/entry_points.txt +0 -0
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import json
5
5
  import sys
6
- from typing import Optional, Any
6
+ from typing import Any, Optional
7
7
 
8
8
  import sqlalchemy as sql
9
9
 
@@ -11,10 +11,10 @@ import pixeltable.catalog as catalog
11
11
  import pixeltable.exceptions as excs
12
12
  import pixeltable.func as func
13
13
  import pixeltable.type_system as ts
14
+
14
15
  from .data_row import DataRow
15
16
  from .expr import Expr
16
- from .inline_array import InlineArray
17
- from .inline_dict import InlineDict
17
+ from .inline_expr import InlineDict, InlineList
18
18
  from .row_builder import RowBuilder
19
19
  from .rowid_ref import RowidRef
20
20
  from .sql_element_cache import SqlElementCache
@@ -53,7 +53,7 @@ class FunctionCall(Expr):
53
53
  super().__init__(fn.call_return_type(bound_args))
54
54
  self.fn = fn
55
55
  self.is_method_call = is_method_call
56
- self.normalize_args(signature, bound_args)
56
+ self.normalize_args(fn.name, signature, bound_args)
57
57
 
58
58
  self.agg_init_args = {}
59
59
  if self.is_agg_fn_call:
@@ -143,7 +143,7 @@ class FunctionCall(Expr):
143
143
  return super().default_column_name()
144
144
 
145
145
  @classmethod
146
- def normalize_args(cls, signature: func.Signature, bound_args: dict[str, Any]) -> None:
146
+ def normalize_args(cls, fn_name: str, signature: func.Signature, bound_args: dict[str, Any]) -> None:
147
147
  """Converts all args to Exprs and checks that they are compatible with signature.
148
148
 
149
149
  Updates bound_args in place, where necessary.
@@ -163,9 +163,7 @@ class FunctionCall(Expr):
163
163
 
164
164
  if isinstance(arg, list) or isinstance(arg, tuple):
165
165
  try:
166
- # If the column type is JsonType, force the literal to be JSON
167
- is_json = is_var_param or (param.col_type is not None and param.col_type.is_json_type())
168
- arg = InlineArray(arg, force_json=is_json)
166
+ arg = InlineList(arg)
169
167
  bound_args[param_name] = arg
170
168
  continue
171
169
  except excs.Error:
@@ -177,7 +175,7 @@ class FunctionCall(Expr):
177
175
  try:
178
176
  _ = json.dumps(arg)
179
177
  except TypeError:
180
- raise excs.Error(f'Argument for parameter {param_name!r} is not json-serializable: {arg}')
178
+ raise excs.Error(f'Argument for parameter {param_name!r} is not json-serializable: {arg} (of type {type(arg)})')
181
179
  if arg is not None:
182
180
  try:
183
181
  param_type = param.col_type
@@ -215,7 +213,7 @@ class FunctionCall(Expr):
215
213
  or (arg.col_type.is_json_type() and param.col_type.is_scalar_type())
216
214
  ):
217
215
  raise excs.Error(
218
- f'Parameter {param_name}: argument type {arg.col_type} does not match parameter type '
216
+ f'Parameter {param_name} (in function {fn_name}): argument type {arg.col_type} does not match parameter type '
219
217
  f'{param.col_type}')
220
218
 
221
219
  def _equals(self, other: FunctionCall) -> bool:
@@ -0,0 +1,200 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import Any, Iterable, Optional
5
+
6
+ import numpy as np
7
+ import sqlalchemy as sql
8
+
9
+ import pixeltable.exceptions as excs
10
+ import pixeltable.type_system as ts
11
+
12
+ from .data_row import DataRow
13
+ from .expr import Expr
14
+ from .literal import Literal
15
+ from .row_builder import RowBuilder
16
+ from .sql_element_cache import SqlElementCache
17
+
18
+
19
+ class InlineArray(Expr):
20
+ """
21
+ Array 'literal' which can use Exprs as values.
22
+ """
23
+
24
+ def __init__(self, elements: Iterable):
25
+ exprs = []
26
+ for el in elements:
27
+ if isinstance(el, Expr):
28
+ exprs.append(el)
29
+ elif isinstance(el, list) or isinstance(el, tuple):
30
+ exprs.append(InlineArray(el))
31
+ else:
32
+ exprs.append(Literal(el))
33
+
34
+ inferred_element_type: Optional[ts.ColumnType] = ts.InvalidType()
35
+ for i, expr in enumerate(exprs):
36
+ supertype = inferred_element_type.supertype(expr.col_type)
37
+ if supertype is None:
38
+ raise excs.Error(
39
+ f'Could not infer element type of array: element of type `{expr.col_type}` at index {i} '
40
+ f'is not compatible with type `{inferred_element_type}` of preceding elements'
41
+ )
42
+ inferred_element_type = supertype
43
+
44
+ if inferred_element_type.is_scalar_type():
45
+ col_type = ts.ArrayType((len(exprs),), inferred_element_type)
46
+ elif inferred_element_type.is_array_type():
47
+ assert isinstance(inferred_element_type, ts.ArrayType)
48
+ col_type = ts.ArrayType(
49
+ (len(exprs), *inferred_element_type.shape),
50
+ ts.ColumnType.make_type(inferred_element_type.dtype)
51
+ )
52
+ else:
53
+ raise excs.Error(f'Element type is not a valid dtype for an array: {inferred_element_type}')
54
+
55
+ super().__init__(col_type)
56
+ self.components.extend(exprs)
57
+ self.id = self._create_id()
58
+
59
+ def __str__(self) -> str:
60
+ elem_strs = [str(expr) for expr in self.components]
61
+ return f'[{", ".join(elem_strs)}]'
62
+
63
+ def _equals(self, _: InlineArray) -> bool:
64
+ return True # Always true if components match
65
+
66
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
67
+ return None
68
+
69
+ def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
70
+ data_row[self.slot_idx] = np.array([data_row[el.slot_idx] for el in self.components])
71
+
72
+ def _as_dict(self) -> dict:
73
+ return super()._as_dict()
74
+
75
+ @classmethod
76
+ def _from_dict(cls, _: dict, components: list[Expr]) -> Expr:
77
+ try:
78
+ return cls(components)
79
+ except excs.Error:
80
+ # For legacy compatibility reasons, we need to try constructing as an `InlineList`.
81
+ # This is because in schema versions <= 19, `InlineArray` was serialized incorrectly, and
82
+ # there is no way to determine the correct expression type until the subexpressions are
83
+ # loaded and their types are known.
84
+ return InlineList(components)
85
+
86
+
87
+ class InlineList(Expr):
88
+ """
89
+ List 'literal' which can use Exprs as values.
90
+ """
91
+
92
+ def __init__(self, elements: Iterable):
93
+ exprs = []
94
+ for el in elements:
95
+ if isinstance(el, Expr):
96
+ exprs.append(el)
97
+ elif isinstance(el, list) or isinstance(el, tuple):
98
+ exprs.append(InlineList(el))
99
+ elif isinstance(el, dict):
100
+ exprs.append(InlineDict(el))
101
+ else:
102
+ exprs.append(Literal(el))
103
+
104
+ super().__init__(ts.JsonType())
105
+ self.components.extend(exprs)
106
+ self.id = self._create_id()
107
+
108
+ def __str__(self) -> str:
109
+ elem_strs = [str(expr) for expr in self.components]
110
+ return f'[{", ".join(elem_strs)}]'
111
+
112
+ def _equals(self, _: InlineList) -> bool:
113
+ return True # Always true if components match
114
+
115
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
116
+ return None
117
+
118
+ def eval(self, data_row: DataRow, _: RowBuilder) -> None:
119
+ data_row[self.slot_idx] = [data_row[el.slot_idx] for el in self.components]
120
+
121
+ def _as_dict(self) -> dict:
122
+ return super()._as_dict()
123
+
124
+ @classmethod
125
+ def _from_dict(cls, _: dict, components: list[Expr]) -> Expr:
126
+ return cls(components)
127
+
128
+
129
+ class InlineDict(Expr):
130
+ """
131
+ Dictionary 'literal' which can use Exprs as values.
132
+ """
133
+
134
+ keys: list[str]
135
+
136
+ def __init__(self, d: dict[str, Any]):
137
+ self.keys = []
138
+ exprs: list[Expr] = []
139
+ for key, val in d.items():
140
+ if not isinstance(key, str):
141
+ raise excs.Error(f'Dictionary requires string keys; {key} has type {type(key)}')
142
+ self.keys.append(key)
143
+ if isinstance(val, Expr):
144
+ exprs.append(val)
145
+ elif isinstance(val, dict):
146
+ exprs.append(InlineDict(val))
147
+ elif isinstance(val, list) or isinstance(val, tuple):
148
+ exprs.append(InlineList(val))
149
+ else:
150
+ exprs.append(Literal(val))
151
+
152
+ super().__init__(ts.JsonType())
153
+ self.components.extend(exprs)
154
+ self.id = self._create_id()
155
+
156
+ def __str__(self) -> str:
157
+ item_strs = list(f"'{key}': {str(expr)}" for key, expr in zip(self.keys, self.components))
158
+ return '{' + ', '.join(item_strs) + '}'
159
+
160
+ def _equals(self, other: InlineDict) -> bool:
161
+ # The dict values are just the components, which have already been checked
162
+ return self.keys == other.keys
163
+
164
+ def _id_attrs(self) -> list[tuple[str, Any]]:
165
+ return super()._id_attrs() + [('keys', self.keys)]
166
+
167
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
168
+ return None
169
+
170
+ def eval(self, data_row: DataRow, _: RowBuilder) -> None:
171
+ assert len(self.keys) == len(self.components)
172
+ data_row[self.slot_idx] = {
173
+ key: data_row[expr.slot_idx]
174
+ for key, expr in zip(self.keys, self.components)
175
+ }
176
+
177
+ def to_kwargs(self) -> dict[str, Any]:
178
+ """Deconstructs this expression into a dictionary by recursively unwrapping all Literals,
179
+ InlineDicts, and InlineLists."""
180
+ return InlineDict._to_kwarg_element(self)
181
+
182
+ @classmethod
183
+ def _to_kwarg_element(cls, expr: Expr) -> Any:
184
+ if isinstance(expr, Literal):
185
+ return expr.val
186
+ if isinstance(expr, InlineDict):
187
+ return {key: cls._to_kwarg_element(val) for key, val in zip(expr.keys, expr.components)}
188
+ if isinstance(expr, InlineList):
189
+ return [cls._to_kwarg_element(el) for el in expr.components]
190
+ return expr
191
+
192
+ def _as_dict(self) -> dict[str, Any]:
193
+ return {'keys': self.keys, **super()._as_dict()}
194
+
195
+ @classmethod
196
+ def _from_dict(cls, d: dict, components: list[Expr]) -> Expr:
197
+ assert 'keys' in d
198
+ assert len(d['keys']) == len(components)
199
+ arg = dict(zip(d['keys'], components))
200
+ return InlineDict(arg)
@@ -105,12 +105,9 @@ class JsonPath(Expr):
105
105
  return JsonPath(self._anchor, self.path_elements + [name])
106
106
 
107
107
  def __getitem__(self, index: object) -> 'JsonPath':
108
- if isinstance(index, str):
109
- if index != '*':
110
- raise excs.Error(f'Invalid json list index: {index}')
111
- elif not isinstance(index, (int, slice)):
112
- raise excs.Error(f'Invalid json list index: {index}')
113
- return JsonPath(self._anchor, self.path_elements + [index])
108
+ if isinstance(index, (int, slice, str)):
109
+ return JsonPath(self._anchor, self.path_elements + [index])
110
+ raise excs.Error(f'Invalid json list index: {index}')
114
111
 
115
112
  def __rshift__(self, other: object) -> 'JsonMapper':
116
113
  rhs_expr = Expr.from_object(other)
@@ -19,6 +19,8 @@ def transcribe(
19
19
  equivalent to the WhisperX `transcribe` function, as described in the
20
20
  [WhisperX library documentation](https://github.com/m-bain/whisperX).
21
21
 
22
+ WhisperX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
23
+
22
24
  __Requirements:__
23
25
 
24
26
  - `pip install whisperx`
@@ -26,8 +26,7 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
26
26
  Computes YOLOX object detections for the specified image. `model_id` should reference one of the models
27
27
  defined in the [YOLOX documentation](https://github.com/Megvii-BaseDetection/YOLOX).
28
28
 
29
- YOLOX support is part of the `pixeltable.ext` package: long-term support is not guaranteed, and it is not
30
- intended for use in production applications.
29
+ YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
31
30
 
32
31
  __Requirements__:
33
32
 
@@ -79,6 +78,8 @@ def yolo_to_coco(detections: dict) -> list:
79
78
  """
80
79
  Converts the output of a YOLOX object detection model to COCO format.
81
80
 
81
+ YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
82
+
82
83
  Args:
83
84
  detections: The output of a YOLOX object detection model, as returned by `yolox`.
84
85
 
@@ -89,7 +90,8 @@ def yolo_to_coco(detections: dict) -> list:
89
90
  Add a computed column that converts the output `tbl.detections` to COCO format, where `tbl.image`
90
91
  is the image for which detections were computed:
91
92
 
92
- >>> tbl['detections_coco'] = yolo_to_coco(tbl.detections)
93
+ >>> tbl['detections'] = yolox(tbl.image, model_id='yolox_m', threshold=0.8)
94
+ ... tbl['detections_coco'] = yolo_to_coco(tbl.detections)
93
95
  """
94
96
  bboxes, labels = detections['bboxes'], detections['labels']
95
97
  num_annotations = len(detections['bboxes'])
@@ -185,7 +185,7 @@ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndar
185
185
 
186
186
  Examples:
187
187
  Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
188
- Pixeltable column `tbl.image` of the table `tbl`:
188
+ Pixeltable column `image` of the table `tbl`:
189
189
 
190
190
  >>> tbl['result'] = clip_image(tbl.image, model_id='openai/clip-vit-base-patch32')
191
191
  """
@@ -228,24 +228,24 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
228
228
 
229
229
  Args:
230
230
  image: The image to embed.
231
- model_id: The pretrained model to use for the embedding.
231
+ model_id: The pretrained model to use for object detection.
232
232
 
233
233
  Returns:
234
234
  A dictionary containing the output of the object detection model, in the following format:
235
235
 
236
- ```python
237
- {
238
- 'scores': [0.99, 0.999], # list of confidence scores for each detected object
239
- 'labels': [25, 25], # list of COCO class labels for each detected object
240
- 'label_text': ['giraffe', 'giraffe'], # corresponding text names of class labels
241
- 'boxes': [[51.942, 356.174, 181.481, 413.975], [383.225, 58.66, 605.64, 361.346]]
242
- # list of bounding boxes for each detected object, as [x1, y1, x2, y2]
243
- }
244
- ```
236
+ ```python
237
+ {
238
+ 'scores': [0.99, 0.999], # list of confidence scores for each detected object
239
+ 'labels': [25, 25], # list of COCO class labels for each detected object
240
+ 'label_text': ['giraffe', 'giraffe'], # corresponding text names of class labels
241
+ 'boxes': [[51.942, 356.174, 181.481, 413.975], [383.225, 58.66, 605.64, 361.346]]
242
+ # list of bounding boxes for each detected object, as [x1, y1, x2, y2]
243
+ }
244
+ ```
245
245
 
246
246
  Examples:
247
247
  Add a computed column that applies the model `facebook/detr-resnet-50` to an existing
248
- Pixeltable column `tbl.image` of the table `tbl`:
248
+ Pixeltable column `image` of the table `tbl`:
249
249
 
250
250
  >>> tbl['detections'] = detr_for_object_detection(
251
251
  ... tbl.image,
@@ -282,6 +282,83 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
282
282
  ]
283
283
 
284
284
 
285
+ @pxt.udf(batch_size=4)
286
+ def vit_for_image_classification(
287
+ image: Batch[PIL.Image.Image],
288
+ *,
289
+ model_id: str,
290
+ top_k: int = 5
291
+ ) -> Batch[list[dict[str, Any]]]:
292
+ """
293
+ Computes image classifications for the specified image using a Vision Transformer (ViT) model.
294
+ `model_id` should be a reference to a pretrained [ViT Model](https://huggingface.co/docs/transformers/en/model_doc/vit).
295
+
296
+ __Note:__ Be sure the model is a ViT model that is trained for image classification (that is, a model designed for
297
+ use with the
298
+ [ViTForImageClassification](https://huggingface.co/docs/transformers/en/model_doc/vit#transformers.ViTForImageClassification)
299
+ class), such as `google/vit-base-patch16-224`. General feature-extraction models such as
300
+ `google/vit-base-patch16-224-in21k` will not produce the desired results.
301
+
302
+ __Requirements:__
303
+
304
+ - `pip install transformers`
305
+
306
+ Args:
307
+ image: The image to classify.
308
+ model_id: The pretrained model to use for the classification.
309
+ top_k: The number of classes to return.
310
+
311
+ Returns:
312
+ A list of the `top_k` highest-scoring classes for each image. Each element in the list is a dictionary
313
+ in the following format:
314
+
315
+ ```python
316
+ {
317
+ 'p': 0.230, # class probability
318
+ 'class': 935, # class ID
319
+ 'label': 'mashed potato', # class label
320
+ }
321
+ ```
322
+
323
+ Examples:
324
+ Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
325
+ Pixeltable column `image` of the table `tbl`:
326
+
327
+ >>> tbl['image_class'] = vit_for_image_classification(
328
+ ... tbl.image,
329
+ ... model_id='google/vit-base-patch16-224'
330
+ ... )
331
+ """
332
+ env.Env.get().require_package('transformers')
333
+ device = resolve_torch_device('auto')
334
+ import torch
335
+ from transformers import ViTImageProcessor, ViTForImageClassification
336
+
337
+ model: ViTForImageClassification = _lookup_model(model_id, ViTForImageClassification.from_pretrained, device=device)
338
+ processor = _lookup_processor(model_id, ViTImageProcessor.from_pretrained)
339
+ normalized_images = [normalize_image_mode(img) for img in image]
340
+
341
+ with torch.no_grad():
342
+ inputs = processor(images=normalized_images, return_tensors='pt')
343
+ outputs = model(**inputs.to(device))
344
+ logits = outputs.logits
345
+
346
+ probs = torch.softmax(logits, dim=-1)
347
+ top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
348
+
349
+ return [
350
+ [
351
+ {
352
+ 'p': top_k_probs[n, k].item(),
353
+ 'class': top_k_indices[n, k].item(),
354
+ 'label': model.config.id2label[top_k_indices[n, k].item()],
355
+ }
356
+ for k in range(top_k_probs.shape[1])
357
+ ]
358
+ for n in range(top_k_probs.shape[0])
359
+ ]
360
+
361
+
285
362
  @pxt.udf
286
363
  def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str, Any]:
287
364
  """
@@ -92,7 +92,7 @@ def _(self: Expr, mode: str) -> ts.ColumnType:
92
92
 
93
93
 
94
94
  # Image.crop()
95
- @func.udf(substitute_fn=PIL.Image.Image.crop, param_types=[ts.ImageType(), ts.ArrayType((4,), dtype=ts.IntType())], is_method=True)
95
+ @func.udf(substitute_fn=PIL.Image.Image.crop, is_method=True)
96
96
  def crop(self: PIL.Image.Image, box: tuple[int, int, int, int]) -> PIL.Image.Image:
97
97
  """
98
98
  Return a rectangular region from the image. The box is a 4-tuple defining the left, upper, right, and lower pixel
@@ -151,7 +151,7 @@ def _(self: Expr) -> ts.ColumnType:
151
151
 
152
152
 
153
153
  # Image.resize()
154
- @func.udf(param_types=[ts.ImageType(), ts.ArrayType((2,), dtype=ts.IntType())], is_method=True)
154
+ @func.udf(is_method=True)
155
155
  def resize(self: PIL.Image.Image, size: tuple[int, int]) -> PIL.Image.Image:
156
156
  """
157
157
  Return a resized copy of the image. The size parameter is a tuple containing the width and height of the new image.
@@ -366,7 +366,7 @@ def quantize(
366
366
 
367
367
 
368
368
  @func.udf(substitute_fn=PIL.Image.Image.reduce, is_method=True)
369
- def reduce(self: PIL.Image.Image, factor: int, box: Optional[tuple[int]] = None) -> PIL.Image.Image:
369
+ def reduce(self: PIL.Image.Image, factor: int, box: Optional[tuple[int, int, int, int]] = None) -> PIL.Image.Image:
370
370
  """
371
371
  Reduce the image by the given factor.
372
372
 
@@ -7,12 +7,15 @@ the [Working with Together AI](https://pixeltable.readme.io/docs/together-ai) tu
7
7
 
8
8
  import base64
9
9
  import io
10
- from typing import TYPE_CHECKING, Optional
10
+ from typing import TYPE_CHECKING, Callable, Optional, TypeVar
11
11
 
12
12
  import numpy as np
13
13
  import PIL.Image
14
+ import requests
15
+ import tenacity
14
16
 
15
17
  import pixeltable as pxt
18
+ import pixeltable.exceptions as excs
16
19
  from pixeltable import env
17
20
  from pixeltable.func import Batch
18
21
  from pixeltable.utils.code import local_public_names
@@ -24,7 +27,6 @@ if TYPE_CHECKING:
24
27
  @env.register_client('together')
25
28
  def _(api_key: str) -> 'together.Together':
26
29
  import together
27
-
28
30
  return together.Together(api_key=api_key)
29
31
 
30
32
 
@@ -32,6 +34,18 @@ def _together_client() -> 'together.Together':
32
34
  return env.Env.get().get_client('together')
33
35
 
34
36
 
37
+ T = TypeVar('T')
38
+
39
+
40
+ def _retry(fn: Callable[..., T]) -> Callable[..., T]:
41
+ import together
42
+ return tenacity.retry(
43
+ retry=tenacity.retry_if_exception_type(together.error.RateLimitError),
44
+ wait=tenacity.wait_random_exponential(multiplier=1, max=60),
45
+ stop=tenacity.stop_after_attempt(20),
46
+ )(fn)
47
+
48
+
35
49
  @pxt.udf
36
50
  def completions(
37
51
  prompt: str,
@@ -74,8 +88,7 @@ def completions(
74
88
  >>> tbl['response'] = completions(tbl.prompt, model='mistralai/Mixtral-8x7B-v0.1')
75
89
  """
76
90
  return (
77
- _together_client()
78
- .completions.create(
91
+ _retry(_together_client().completions.create)(
79
92
  prompt=prompt,
80
93
  model=model,
81
94
  max_tokens=max_tokens,
@@ -139,8 +152,7 @@ def chat_completions(
139
152
  ... tbl['response'] = chat_completions(messages, model='mistralai/Mixtral-8x7B-v0.1')
140
153
  """
141
154
  return (
142
- _together_client()
143
- .chat.completions.create(
155
+ _retry(_together_client().chat.completions.create)(
144
156
  messages=messages,
145
157
  model=model,
146
158
  max_tokens=max_tokens,
@@ -198,7 +210,7 @@ def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
198
210
 
199
211
  >>> tbl['response'] = embeddings(tbl.text, model='togethercomputer/m2-bert-80M-8k-retrieval')
200
212
  """
201
- result = _together_client().embeddings.create(input=input, model=model)
213
+ result = _retry(_together_client().embeddings.create)(input=input, model=model)
202
214
  return [np.array(data.embedding, dtype=np.float64) for data in result.data]
203
215
 
204
216
 
@@ -242,20 +254,29 @@ def image_generations(
242
254
  The generated image.
243
255
 
244
256
  Examples:
245
- Add a computed column that applies the model `runwayml/stable-diffusion-v1-5`
257
+ Add a computed column that applies the model `stabilityai/stable-diffusion-xl-base-1.0`
246
258
  to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
247
259
 
248
- >>> tbl['response'] = image_generations(tbl.prompt, model='runwayml/stable-diffusion-v1-5')
260
+ >>> tbl['response'] = image_generations(tbl.prompt, model='stabilityai/stable-diffusion-xl-base-1.0')
249
261
  """
250
- # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
251
- result = _together_client().images.generate(
262
+ result = _retry(_together_client().images.generate)(
252
263
  prompt=prompt, model=model, steps=steps, seed=seed, height=height, width=width, negative_prompt=negative_prompt
253
264
  )
254
- b64_str = result.data[0].b64_json
255
- b64_bytes = base64.b64decode(b64_str)
256
- img = PIL.Image.open(io.BytesIO(b64_bytes))
257
- img.load()
258
- return img
265
+ if result.data[0].b64_json is not None:
266
+ b64_bytes = base64.b64decode(result.data[0].b64_json)
267
+ img = PIL.Image.open(io.BytesIO(b64_bytes))
268
+ img.load()
269
+ return img
270
+ if result.data[0].url is not None:
271
+ try:
272
+ resp = requests.get(result.data[0].url)
273
+ with io.BytesIO(resp.content) as fp:
274
+ image = PIL.Image.open(fp)
275
+ image.load()
276
+ return image
277
+ except Exception as exc:
278
+ raise excs.Error('Failed to download generated image from together.ai.') from exc
279
+ raise excs.Error('Response does not contain a generated image.')
259
280
 
260
281
 
261
282
  __all__ = local_public_names(__name__)