pixeltable 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (87)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +8 -7
  4. pixeltable/catalog/column.py +11 -8
  5. pixeltable/catalog/insertable_table.py +1 -1
  6. pixeltable/catalog/path_dict.py +8 -6
  7. pixeltable/catalog/table.py +20 -14
  8. pixeltable/catalog/table_version.py +92 -55
  9. pixeltable/catalog/table_version_path.py +7 -9
  10. pixeltable/catalog/view.py +3 -2
  11. pixeltable/dataframe.py +2 -2
  12. pixeltable/env.py +205 -86
  13. pixeltable/exceptions.py +5 -1
  14. pixeltable/exec/aggregation_node.py +2 -1
  15. pixeltable/exec/component_iteration_node.py +2 -2
  16. pixeltable/exec/sql_node.py +11 -8
  17. pixeltable/exprs/__init__.py +2 -2
  18. pixeltable/exprs/arithmetic_expr.py +4 -4
  19. pixeltable/exprs/array_slice.py +2 -1
  20. pixeltable/exprs/column_property_ref.py +9 -7
  21. pixeltable/exprs/column_ref.py +2 -1
  22. pixeltable/exprs/comparison.py +10 -7
  23. pixeltable/exprs/compound_predicate.py +3 -2
  24. pixeltable/exprs/data_row.py +19 -4
  25. pixeltable/exprs/expr.py +51 -41
  26. pixeltable/exprs/expr_set.py +32 -9
  27. pixeltable/exprs/function_call.py +62 -40
  28. pixeltable/exprs/in_predicate.py +3 -2
  29. pixeltable/exprs/inline_expr.py +200 -0
  30. pixeltable/exprs/is_null.py +3 -2
  31. pixeltable/exprs/json_mapper.py +5 -4
  32. pixeltable/exprs/json_path.py +7 -1
  33. pixeltable/exprs/literal.py +34 -7
  34. pixeltable/exprs/method_ref.py +3 -3
  35. pixeltable/exprs/object_ref.py +6 -5
  36. pixeltable/exprs/row_builder.py +25 -17
  37. pixeltable/exprs/rowid_ref.py +2 -1
  38. pixeltable/exprs/similarity_expr.py +2 -1
  39. pixeltable/exprs/sql_element_cache.py +30 -0
  40. pixeltable/exprs/type_cast.py +3 -3
  41. pixeltable/exprs/variable.py +2 -1
  42. pixeltable/ext/functions/whisperx.py +6 -4
  43. pixeltable/ext/functions/yolox.py +11 -9
  44. pixeltable/func/aggregate_function.py +1 -0
  45. pixeltable/func/function.py +28 -4
  46. pixeltable/functions/__init__.py +4 -2
  47. pixeltable/functions/anthropic.py +15 -5
  48. pixeltable/functions/fireworks.py +1 -1
  49. pixeltable/functions/globals.py +6 -1
  50. pixeltable/functions/huggingface.py +91 -14
  51. pixeltable/functions/image.py +20 -5
  52. pixeltable/functions/json.py +5 -5
  53. pixeltable/functions/mistralai.py +188 -0
  54. pixeltable/functions/openai.py +6 -10
  55. pixeltable/functions/string.py +3 -2
  56. pixeltable/functions/timestamp.py +95 -7
  57. pixeltable/functions/together.py +18 -11
  58. pixeltable/functions/video.py +2 -2
  59. pixeltable/functions/vision.py +69 -37
  60. pixeltable/functions/whisper.py +4 -1
  61. pixeltable/globals.py +5 -1
  62. pixeltable/io/hf_datasets.py +17 -15
  63. pixeltable/io/pandas.py +0 -2
  64. pixeltable/io/parquet.py +15 -14
  65. pixeltable/iterators/document.py +16 -15
  66. pixeltable/metadata/__init__.py +1 -1
  67. pixeltable/metadata/converters/convert_18.py +1 -1
  68. pixeltable/metadata/converters/convert_19.py +46 -0
  69. pixeltable/metadata/converters/convert_20.py +56 -0
  70. pixeltable/metadata/converters/util.py +29 -4
  71. pixeltable/metadata/notes.py +2 -0
  72. pixeltable/metadata/schema.py +5 -4
  73. pixeltable/plan.py +100 -78
  74. pixeltable/store.py +5 -1
  75. pixeltable/tool/create_test_db_dump.py +18 -6
  76. pixeltable/type_system.py +15 -15
  77. pixeltable/utils/documents.py +45 -42
  78. pixeltable/utils/formatter.py +2 -2
  79. pixeltable-0.2.19.dist-info/LICENSE +201 -0
  80. {pixeltable-0.2.17.dist-info → pixeltable-0.2.19.dist-info}/METADATA +84 -24
  81. pixeltable-0.2.19.dist-info/RECORD +147 -0
  82. pixeltable/exprs/inline_array.py +0 -116
  83. pixeltable/exprs/inline_dict.py +0 -103
  84. pixeltable-0.2.17.dist-info/LICENSE +0 -18
  85. pixeltable-0.2.17.dist-info/RECORD +0 -144
  86. {pixeltable-0.2.17.dist-info → pixeltable-0.2.19.dist-info}/WHEEL +0 -0
  87. {pixeltable-0.2.17.dist-info → pixeltable-0.2.19.dist-info}/entry_points.txt +0 -0
@@ -1,10 +1,10 @@
 import logging
 from pathlib import Path
-from typing import Iterable, Iterator, TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterable, Iterator
 from urllib.request import urlretrieve
 
-import PIL.Image
 import numpy as np
+import PIL.Image
 
 import pixeltable as pxt
 from pixeltable import env
@@ -14,8 +14,8 @@ from pixeltable.utils.code import local_public_names
 
 if TYPE_CHECKING:
     import torch
-    from yolox.exp import Exp
-    from yolox.models import YOLOX
+    from yolox.exp import Exp  # type: ignore[import-untyped]
+    from yolox.models import YOLOX  # type: ignore[import-untyped]
 
 _logger = logging.getLogger('pixeltable')
 
@@ -26,8 +26,7 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
     Computes YOLOX object detections for the specified image. `model_id` should reference one of the models
     defined in the [YOLOX documentation](https://github.com/Megvii-BaseDetection/YOLOX).
 
-    YOLOX support is part of the `pixeltable.ext` package: long-term support is not guaranteed, and it is not
-    intended for use in production applications.
+    YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
 
     __Requirements__:
 
@@ -47,7 +46,7 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
     >>> tbl['detections'] = yolox(tbl.image, model_id='yolox_m', threshold=0.8)
     """
     import torch
-    from yolox.utils import postprocess
+    from yolox.utils import postprocess  # type: ignore[import-untyped]
 
     model, exp = _lookup_model(model_id, 'cpu')
     image_tensors = list(_images_to_tensors(images, exp))
@@ -79,6 +78,8 @@ def yolo_to_coco(detections: dict) -> list:
     """
     Converts the output of a YOLOX object detection model to COCO format.
 
+    YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
+
     Args:
         detections: The output of a YOLOX object detection model, as returned by `yolox`.
 
@@ -89,7 +90,8 @@ def yolo_to_coco(detections: dict) -> list:
     Add a computed column that converts the output `tbl.detections` to COCO format, where `tbl.image`
     is the image for which detections were computed:
 
-    >>> tbl['detections_coco'] = yolo_to_coco(tbl.detections)
+    >>> tbl['detections'] = yolox(tbl.image, model_id='yolox_m', threshold=0.8)
+    ... tbl['detections_coco'] = yolo_to_coco(tbl.detections)
     """
     bboxes, labels = detections['bboxes'], detections['labels']
     num_annotations = len(detections['bboxes'])
@@ -107,7 +109,7 @@ def yolo_to_coco(detections: dict) -> list:
 
 def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: 'Exp') -> Iterator['torch.Tensor']:
     import torch
-    from yolox.data import ValTransform
+    from yolox.data import ValTransform  # type: ignore[import-untyped]
 
     _val_transform = ValTransform(legacy=False)
     for image in images:
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
 class Aggregator(abc.ABC):
     def update(self, *args: Any, **kwargs: Any) -> None:
         pass
+
     def value(self) -> Any:
         pass
 
@@ -5,10 +5,10 @@ import importlib
 import inspect
 from typing import Any, Callable, Dict, Optional, Tuple
 
+import sqlalchemy as sql
+
 import pixeltable
-import pixeltable.exceptions as excs
 import pixeltable.type_system as ts
-
 from .globals import resolve_symbol
 from .signature import Signature
 
@@ -21,14 +21,29 @@ class Function(abc.ABC):
     via the member self_path.
     """
 
-    def __init__(self, signature: Signature, self_path: Optional[str] = None, is_method: bool = False, is_property: bool = False):
+    signature: Signature
+    self_path: Optional[str]
+    is_method: bool
+    is_property: bool
+    _conditional_return_type: Optional[Callable[..., ts.ColumnType]]
+
+    # Translates a call to this function with the given arguments to its SQLAlchemy equivalent.
+    # Overridden for specific Function instances via the to_sql() decorator. The override must accept the same
+    # parameter names as the original function. Each parameter is going to be of type sql.ColumnElement.
+    _to_sql: Callable[..., Optional[sql.ColumnElement]]
+
+
+    def __init__(
+        self, signature: Signature, self_path: Optional[str] = None, is_method: bool = False, is_property: bool = False
+    ):
         # Check that stored functions cannot be declared using `is_method` or `is_property`:
         assert not ((is_method or is_property) and self_path is None)
         self.signature = signature
         self.self_path = self_path  # fully-qualified path to self
         self.is_method = is_method
         self.is_property = is_property
-        self._conditional_return_type: Optional[Callable[..., ts.ColumnType]] = None
+        self._conditional_return_type = None
+        self._to_sql = self.__default_to_sql
 
     @property
     def name(self) -> str:
@@ -88,6 +103,15 @@ class Function(abc.ABC):
         """Execute the function with the given arguments and return the result."""
         pass
 
+    def to_sql(self, fn: Callable[..., Optional[sql.ColumnElement]]) -> Callable[..., Optional[sql.ColumnElement]]:
+        """Instance decorator for specifying the SQL translation of this function"""
+        self._to_sql = fn
+        return fn
+
+    def __default_to_sql(self, *args: Any, **kwargs: Any) -> Optional[sql.ColumnElement]:
+        """The default implementation of SQL translation, which provides no translation"""
+        return None
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
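The `to_sql` hook is an instance decorator: applying `@some_udf.to_sql` to a helper registers that helper as the UDF's SQL translation, with parameter names matching the UDF's (the commented-out `@sum.to_sql` block further down shows the same pattern for an aggregate). A minimal sketch of the intended usage, assuming a hypothetical `str_lower` UDF that is not part of this release:

```python
import sqlalchemy as sql
import pixeltable as pxt

@pxt.udf
def str_lower(s: str) -> str:
    # Python implementation, used when the call cannot be evaluated in SQL
    return s.lower()

# Register a SQL translation; the parameter name `s` matches the UDF's.
# Returning None (the default _to_sql behavior) means "no translation available".
@str_lower.to_sql
def _(s: sql.ColumnElement) -> sql.ColumnElement:
    return sql.func.lower(s)
```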
@@ -1,7 +1,9 @@
-from . import anthropic, audio, fireworks, huggingface, image, json, openai, string, timestamp, together, video, vision
-from .globals import *
 from pixeltable.utils.code import local_public_names
 
+from . import (anthropic, audio, fireworks, huggingface, image, json, mistralai, openai, string, timestamp, together,
+               video, vision)
+from .globals import *
+
 __all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
 
 
@@ -5,7 +5,9 @@ first `pip install anthropic` and configure your Anthropic credentials, as descr
 the [Working with Anthropic](https://pixeltable.readme.io/docs/working-with-anthropic) tutorial.
 """
 
-from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
+
+import tenacity
 
 import pixeltable as pxt
 from pixeltable import env
@@ -18,7 +20,6 @@ if TYPE_CHECKING:
 @env.register_client('anthropic')
 def _(api_key: str) -> 'anthropic.Anthropic':
     import anthropic
-
     return anthropic.Anthropic(api_key=api_key)
 
 
@@ -26,6 +27,15 @@ def _anthropic_client() -> 'anthropic.Anthropic':
     return env.Env.get().get_client('anthropic')
 
 
+def _retry(fn: Callable) -> Callable:
+    import anthropic
+    return tenacity.retry(
+        retry=tenacity.retry_if_exception_type(anthropic.RateLimitError),
+        wait=tenacity.wait_random_exponential(multiplier=1, max=60),
+        stop=tenacity.stop_after_attempt(20),
+    )(fn)
+
+
 @pxt.udf
 def messages(
     messages: list[dict[str, str]],
@@ -67,7 +77,7 @@ def messages(
     >>> msgs = [{'role': 'user', 'content': tbl.prompt}]
     ... tbl['response'] = messages(msgs, model='claude-3-haiku-20240307')
     """
-    return _anthropic_client().messages.create(
+    return _retry(_anthropic_client().messages.create)(
         messages=messages,
         model=model,
         max_tokens=max_tokens,
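The new `_retry` helper wraps an API callable in a tenacity retry policy: retry only on `anthropic.RateLimitError`, with randomized exponential backoff capped at 60 seconds and at most 20 attempts. A self-contained sketch of the same pattern, with `QuotaError` and `flaky_call` as illustrative stand-ins:

```python
import tenacity

class QuotaError(Exception):
    """Stand-in for anthropic.RateLimitError."""

attempts = 0

def flaky_call() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise QuotaError('rate limited')  # first two attempts are throttled
    return 'ok'

# Same policy as _retry: retry only on rate-limit errors, with jittered
# exponential backoff capped at 60s, giving up after 20 attempts.
retrying_call = tenacity.retry(
    retry=tenacity.retry_if_exception_type(QuotaError),
    wait=tenacity.wait_random_exponential(multiplier=1, max=60),
    stop=tenacity.stop_after_attempt(20),
)(flaky_call)

print(retrying_call())  # 'ok' on the third attempt
```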
@@ -86,8 +96,8 @@
 
 
 def _opt(arg: _T) -> Union[_T, 'anthropic.NotGiven']:
-    from anthropic import NOT_GIVEN
-    return arg if arg is not None else NOT_GIVEN
+    import anthropic
+    return arg if arg is not None else anthropic.NOT_GIVEN
 
 
 __all__ = local_public_names(__name__)
@@ -12,7 +12,7 @@ from pixeltable import env
 from pixeltable.utils.code import local_public_names
 
 if TYPE_CHECKING:
-    import fireworks.client
+    import fireworks.client  # type: ignore[import-untyped]
 
 
 @env.register_client('fireworks')
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional, Union, Any
 
 import pixeltable.func as func
 import pixeltable.type_system as ts
@@ -25,6 +25,11 @@ class sum(func.Aggregator):
     def value(self) -> Union[int, float]:
         return self.sum
 
+# @sum.to_sql
+# def _(val: 'sqlalchemy.ColumnElements') -> Optional['sqlalchemy.ColumnElements']:
+#     import sqlalchemy as sql
+#     return sql.sql.functions.sum(val)
+
 
 @func.uda(update_types=[ts.IntType()], value_type=ts.IntType(), allows_window=True, requires_order_by=False)
 class count(func.Aggregator):
@@ -185,7 +185,7 @@ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndar
 
     Examples:
         Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
-        Pixeltable column `tbl.image` of the table `tbl`:
+        Pixeltable column `image` of the table `tbl`:
 
         >>> tbl['result'] = clip_image(tbl.image, model_id='openai/clip-vit-base-patch32')
     """
@@ -228,24 +228,24 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
 
     Args:
         image: The image to embed.
-        model_id: The pretrained model to use for the embedding.
+        model_id: The pretrained model to use for object detection.
 
     Returns:
         A dictionary containing the output of the object detection model, in the following format:
 
-    ```python
-    {
-        'scores': [0.99, 0.999],  # list of confidence scores for each detected object
-        'labels': [25, 25],  # list of COCO class labels for each detected object
-        'label_text': ['giraffe', 'giraffe'],  # corresponding text names of class labels
-        'boxes': [[51.942, 356.174, 181.481, 413.975], [383.225, 58.66, 605.64, 361.346]]
-            # list of bounding boxes for each detected object, as [x1, y1, x2, y2]
-    }
-    ```
+        ```python
+        {
+            'scores': [0.99, 0.999],  # list of confidence scores for each detected object
+            'labels': [25, 25],  # list of COCO class labels for each detected object
+            'label_text': ['giraffe', 'giraffe'],  # corresponding text names of class labels
+            'boxes': [[51.942, 356.174, 181.481, 413.975], [383.225, 58.66, 605.64, 361.346]]
+                # list of bounding boxes for each detected object, as [x1, y1, x2, y2]
+        }
+        ```
 
     Examples:
         Add a computed column that applies the model `facebook/detr-resnet-50` to an existing
-        Pixeltable column `tbl.image` of the table `tbl`:
+        Pixeltable column `image` of the table `tbl`:
 
         >>> tbl['detections'] = detr_for_object_detection(
         ...     tbl.image,
@@ -282,6 +282,83 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
     ]
 
 
+@pxt.udf(batch_size=4)
+def vit_for_image_classification(
+    image: Batch[PIL.Image.Image],
+    *,
+    model_id: str,
+    top_k: int = 5
+) -> Batch[list[dict[str, Any]]]:
+    """
+    Computes image classifications for the specified image using a Vision Transformer (ViT) model.
+    `model_id` should be a reference to a pretrained [ViT Model](https://huggingface.co/docs/transformers/en/model_doc/vit).
+
+    __Note:__ Be sure the model is a ViT model that is trained for image classification (that is, a model designed for
+    use with the
+    [ViTForImageClassification](https://huggingface.co/docs/transformers/en/model_doc/vit#transformers.ViTForImageClassification)
+    class), such as `google/vit-base-patch16-224`. General feature-extraction models such as
+    `google/vit-base-patch16-224-in21k` will not produce the desired results.
+
+    __Requirements:__
+
+    - `pip install transformers`
+
+    Args:
+        image: The image to classify.
+        model_id: The pretrained model to use for the classification.
+        top_k: The number of classes to return.
+
+    Returns:
+        A list of the `top_k` highest-scoring classes for each image. Each element in the list is a dictionary
+        in the following format:
+
+        ```python
+        {
+            'p': 0.230,  # class probability
+            'class': 935,  # class ID
+            'label': 'mashed potato',  # class label
+        }
+        ```
+
+    Examples:
+        Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
+        Pixeltable column `image` of the table `tbl`:
+
+        >>> tbl['image_class'] = vit_for_image_classification(
+        ...     tbl.image,
+        ...     model_id='google/vit-base-patch16-224'
+        ... )
+    """
+    env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
+    from transformers import ViTImageProcessor, ViTForImageClassification
+
+    model: ViTForImageClassification = _lookup_model(model_id, ViTForImageClassification.from_pretrained, device=device)
+    processor = _lookup_processor(model_id, ViTImageProcessor.from_pretrained)
+    normalized_images = [normalize_image_mode(img) for img in image]
+
+    with torch.no_grad():
+        inputs = processor(images=normalized_images, return_tensors='pt')
+        outputs = model(**inputs.to(device))
+        logits = outputs.logits
+
+    probs = torch.softmax(logits, dim=-1)
+    top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
+
+    return [
+        [
+            {
+                'p': top_k_probs[n, k].item(),
+                'class': top_k_indices[n, k].item(),
+                'label': model.config.id2label[top_k_indices[n, k].item()],
+            }
+            for k in range(top_k_probs.shape[1])
+        ]
+        for n in range(top_k_probs.shape[0])
+    ]
+
+
 @pxt.udf
 def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str, Any]:
     """
@@ -332,8 +409,8 @@ def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
     return _processor_cache[key]
 
 
-_model_cache = {}
-_processor_cache = {}
+_model_cache: dict[tuple[str, Callable, Optional[str]], Any] = {}
+_processor_cache: dict[tuple[str, Callable], Any] = {}
 
 
 __all__ = local_public_names(__name__)
@@ -92,7 +92,7 @@ def _(self: Expr, mode: str) -> ts.ColumnType:
 
 
 # Image.crop()
-@func.udf(substitute_fn=PIL.Image.Image.crop, param_types=[ts.ImageType(), ts.ArrayType((4,), dtype=ts.IntType())], is_method=True)
+@func.udf(substitute_fn=PIL.Image.Image.crop, is_method=True)
 def crop(self: PIL.Image.Image, box: tuple[int, int, int, int]) -> PIL.Image.Image:
     """
     Return a rectangular region from the image. The box is a 4-tuple defining the left, upper, right, and lower pixel
@@ -128,6 +128,21 @@ def getchannel(self: PIL.Image.Image, channel: int) -> PIL.Image.Image:
     pass
 
 
+@func.udf(is_method=True)
+def get_metadata(self: PIL.Image.Image) -> dict:
+    """
+    Return metadata for the image.
+    """
+    return {
+        'width': self.width,
+        'height': self.height,
+        'mode': self.mode,
+        'bits': getattr(self, 'bits', None),
+        'format': self.format,
+        'palette': self.palette,
+    }
+
+
 @getchannel.conditional_return_type
 def _(self: Expr) -> ts.ColumnType:
     input_type = self.col_type
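Because `get_metadata` is registered with `is_method=True`, it should be callable as a method on image expressions. A sketch of the expected usage (the table and column names are illustrative):

```python
import pixeltable as pxt

tbl = pxt.get_table('images')  # assumes an existing table with an image column
# one dict per row, e.g. {'width': 640, 'height': 480, 'mode': 'RGB', ...}
tbl.select(tbl.image.get_metadata()).collect()
```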
@@ -136,7 +151,7 @@ def _(self: Expr) -> ts.ColumnType:
 
 
 # Image.resize()
-@func.udf(param_types=[ts.ImageType(), ts.ArrayType((2,), dtype=ts.IntType())], is_method=True)
+@func.udf(is_method=True)
 def resize(self: PIL.Image.Image, size: tuple[int, int]) -> PIL.Image.Image:
     """
     Return a resized copy of the image. The size parameter is a tuple containing the width and height of the new image.
@@ -144,7 +159,7 @@ def resize(self: PIL.Image.Image, size: tuple[int, int]) -> PIL.Image.Image:
     Equivalent to
     [`PIL.Image.Image.resize()`](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.resize)
     """
-    return self.resize(tuple(size))
+    return self.resize(tuple(size))  # type: ignore[arg-type]
 
 
 @resize.conditional_return_type
@@ -294,7 +309,7 @@ def getpixel(self: PIL.Image.Image, xy: tuple[int, int]) -> tuple[int]:
         xy: The coordinates, given as (x, y).
     """
     # `xy` will be a list; `tuple(xy)` is necessary for pillow 9 compatibility
-    return self.getpixel(tuple(xy))
+    return self.getpixel(tuple(xy))  # type: ignore[arg-type]
 
 
 @func.udf(substitute_fn=PIL.Image.Image.getprojection, is_method=True)
@@ -351,7 +366,7 @@ def quantize(
 
 
 @func.udf(substitute_fn=PIL.Image.Image.reduce, is_method=True)
-def reduce(self: PIL.Image.Image, factor: int, box: Optional[tuple[int]] = None) -> PIL.Image.Image:
+def reduce(self: PIL.Image.Image, factor: int, box: Optional[tuple[int, int, int, int]] = None) -> PIL.Image.Image:
     """
     Reduce the image by the given factor.
 
@@ -12,18 +12,18 @@ t.select(pxt.functions.json.make_list()).collect()
 
 from typing import Any
 
-import pixeltable.func as func
+import pixeltable as pxt
 import pixeltable.type_system as ts
 from pixeltable.utils.code import local_public_names
 
 
-@func.uda(
-    update_types=[ts.JsonType(nullable=True)],
-    value_type=ts.JsonType(),
+@pxt.uda(
+    update_types=[pxt.JsonType(nullable=True)],
+    value_type=pxt.JsonType(),
     requires_order_by=False,
     allows_window=False,
 )
-class make_list(func.Aggregator):
+class make_list(pxt.Aggregator):
     """
     Collects arguments into a list.
     """
@@ -0,0 +1,188 @@
+"""
+Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
+that wrap various endpoints from the Mistral AI API. In order to use them, you must
+first `pip install mistralai` and configure your Mistral AI credentials, as described in
+the [Working with Mistral AI](https://pixeltable.readme.io/docs/working-with-mistralai) tutorial.
+"""
+
+from typing import TYPE_CHECKING, Optional, TypeVar, Union
+
+import numpy as np
+
+import pixeltable as pxt
+from pixeltable.env import Env, register_client
+from pixeltable.func.signature import Batch
+from pixeltable.utils.code import local_public_names
+
+if TYPE_CHECKING:
+    import mistralai.types.basemodel
+
+
+@register_client('mistral')
+def _(api_key: str) -> 'mistralai.Mistral':
+    import mistralai
+    return mistralai.Mistral(api_key=api_key)
+
+
+def _mistralai_client() -> 'mistralai.Mistral':
+    return Env.get().get_client('mistral')
+
+
+@pxt.udf
+def chat_completions(
+    messages: list[dict[str, str]],
+    *,
+    model: str,
+    temperature: Optional[float] = 0.7,
+    top_p: Optional[float] = 1.0,
+    max_tokens: Optional[int] = None,
+    min_tokens: Optional[int] = None,
+    stop: Optional[list[str]] = None,
+    random_seed: Optional[int] = None,
+    response_format: Optional[dict] = None,
+    safe_prompt: Optional[bool] = False,
+) -> dict:
+    """
+    Chat Completion API.
+
+    Equivalent to the Mistral AI `chat/completions` API endpoint.
+    For additional details, see: <https://docs.mistral.ai/api/#tag/chat>
+
+    __Requirements:__
+
+    - `pip install mistralai`
+
+    Args:
+        messages: The prompt(s) to generate completions for.
+        model: ID of the model to use. (See overview here: <https://docs.mistral.ai/getting-started/models/>)
+
+        For details on the other parameters, see: <https://docs.mistral.ai/api/#tag/chat>
+
+    Returns:
+        A dictionary containing the response and other metadata.
+
+    Examples:
+        Add a computed column that applies the model `mistral-latest-small`
+        to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
+
+        >>> messages = [{'role': 'user', 'content': tbl.prompt}]
+        ... tbl['response'] = chat_completions(messages, model='mistral-latest-small')
+    """
+    Env.get().require_package('mistralai')
+    return _mistralai_client().chat.complete(
+        messages=messages,  # type: ignore[arg-type]
+        model=model,
+        temperature=temperature,
+        top_p=top_p,
+        max_tokens=_opt(max_tokens),
+        min_tokens=_opt(min_tokens),
+        stop=stop,
+        random_seed=_opt(random_seed),
+        response_format=response_format,  # type: ignore[arg-type]
+        safe_prompt=safe_prompt,
+    ).dict()
+
+
+@pxt.udf
+def fim_completions(
+    prompt: str,
+    *,
+    model: str,
+    temperature: Optional[float] = 0.7,
+    top_p: Optional[float] = 1.0,
+    max_tokens: Optional[int] = None,
+    min_tokens: Optional[int] = None,
+    stop: Optional[list[str]] = None,
+    random_seed: Optional[int] = None,
+    suffix: Optional[str] = None,
+) -> dict:
+    """
+    Fill-in-the-middle Completion API.
+
+    Equivalent to the Mistral AI `fim/completions` API endpoint.
+    For additional details, see: <https://docs.mistral.ai/api/#tag/fim>
+
+    __Requirements:__
+
+    - `pip install mistralai`
+
+    Args:
+        prompt: The text/code to complete.
+        model: ID of the model to use. (See overview here: <https://docs.mistral.ai/getting-started/models/>)
+
+        For details on the other parameters, see: <https://docs.mistral.ai/api/#tag/fim>
+
+    Returns:
+        A dictionary containing the response and other metadata.
+
+    Examples:
+        Add a computed column that applies the model `codestral-latest`
+        to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
+
+        >>> tbl['response'] = fim_completions(tbl.prompt, model='codestral-latest')
+    """
+    Env.get().require_package('mistralai')
+    return _mistralai_client().fim.complete(
+        prompt=prompt,
+        model=model,
+        temperature=temperature,
+        top_p=top_p,
+        max_tokens=_opt(max_tokens),
+        min_tokens=_opt(min_tokens),
+        stop=stop,
+        random_seed=_opt(random_seed),
+        suffix=_opt(suffix)
+    ).dict()
+
+
+_embedding_dimensions_cache: dict[str, int] = {
+    'mistral-embed': 1024
+}
+
+
+@pxt.udf(batch_size=16, return_type=pxt.ArrayType((None,), dtype=pxt.FloatType()))
+def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
+    """
+    Embeddings API.
+
+    Equivalent to the Mistral AI `embeddings` API endpoint.
+    For additional details, see: <https://docs.mistral.ai/api/#tag/embeddings>
+
+    __Requirements:__
+
+    - `pip install mistralai`
+
+    Args:
+        input: Text to embed.
+        model: ID of the model to use. (See overview here: <https://docs.mistral.ai/getting-started/models/>)
+
+    Returns:
+        An array representing the application of the given embedding to `input`.
+    """
+    Env.get().require_package('mistralai')
+    result = _mistralai_client().embeddings.create(
+        inputs=input,
+        model=model,
+    )
+    return [np.array(data.embedding, dtype=np.float64) for data in result.data]
+
+
+@embeddings.conditional_return_type
+def _(model: str) -> pxt.ArrayType:
+    dimensions = _embedding_dimensions_cache.get(model)  # `None` if unknown model
+    return pxt.ArrayType((dimensions,), dtype=pxt.FloatType())
+
+
+_T = TypeVar('_T')
+
+
+def _opt(arg: Optional[_T]) -> Union[_T, 'mistralai.types.basemodel.Unset']:
+    from mistralai.types import UNSET
+    return arg if arg is not None else UNSET
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__():
+    return __all__
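A sketch of how the new module might be used end to end; the table and column names are illustrative, and the model ids follow the docstrings above:

```python
import pixeltable as pxt
from pixeltable.functions.mistralai import chat_completions, embeddings

tbl = pxt.create_table('docs', {'prompt': pxt.StringType()})

# chat completion as a computed column over the prompt column
msgs = [{'role': 'user', 'content': tbl.prompt}]
tbl['response'] = chat_completions(msgs, model='mistral-latest-small')

# embeddings: for 'mistral-embed', the conditional return type resolves to a
# fixed 1024-dimensional float array via _embedding_dimensions_cache
tbl['embedding'] = embeddings(tbl.prompt, model='mistral-embed')
```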