pixeltable: pixeltable-0.2.5-py3-none-any.whl → pixeltable-0.2.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic; further details were linked from the original page.

Files changed (87):
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +9 -5
  4. pixeltable/catalog/insertable_table.py +0 -2
  5. pixeltable/catalog/table.py +16 -8
  6. pixeltable/catalog/table_version.py +3 -2
  7. pixeltable/dataframe.py +184 -110
  8. pixeltable/env.py +69 -18
  9. pixeltable/exec/__init__.py +2 -1
  10. pixeltable/exec/data_row_batch.py +6 -7
  11. pixeltable/exec/expr_eval_node.py +28 -28
  12. pixeltable/exec/sql_scan_node.py +7 -6
  13. pixeltable/exprs/__init__.py +4 -3
  14. pixeltable/exprs/column_ref.py +9 -0
  15. pixeltable/exprs/expr.py +15 -7
  16. pixeltable/exprs/function_call.py +17 -15
  17. pixeltable/exprs/image_member_access.py +9 -28
  18. pixeltable/exprs/in_predicate.py +96 -0
  19. pixeltable/exprs/inline_array.py +13 -11
  20. pixeltable/exprs/inline_dict.py +15 -13
  21. pixeltable/exprs/row_builder.py +7 -1
  22. pixeltable/exprs/similarity_expr.py +65 -0
  23. pixeltable/func/__init__.py +0 -2
  24. pixeltable/func/aggregate_function.py +3 -0
  25. pixeltable/func/callable_function.py +57 -13
  26. pixeltable/func/expr_template_function.py +11 -2
  27. pixeltable/func/function.py +35 -4
  28. pixeltable/func/signature.py +5 -15
  29. pixeltable/func/udf.py +6 -10
  30. pixeltable/functions/huggingface.py +23 -4
  31. pixeltable/functions/openai.py +34 -1
  32. pixeltable/functions/pil/image.py +61 -64
  33. pixeltable/functions/together.py +21 -0
  34. pixeltable/globals.py +425 -0
  35. pixeltable/index/base.py +3 -1
  36. pixeltable/index/embedding_index.py +87 -14
  37. pixeltable/io/__init__.py +3 -0
  38. pixeltable/{utils → io}/hf_datasets.py +48 -17
  39. pixeltable/io/pandas.py +148 -0
  40. pixeltable/{utils → io}/parquet.py +58 -33
  41. pixeltable/iterators/__init__.py +1 -1
  42. pixeltable/iterators/base.py +4 -0
  43. pixeltable/iterators/document.py +218 -97
  44. pixeltable/iterators/video.py +8 -9
  45. pixeltable/metadata/__init__.py +7 -3
  46. pixeltable/metadata/converters/convert_12.py +3 -0
  47. pixeltable/metadata/converters/convert_13.py +41 -0
  48. pixeltable/plan.py +2 -19
  49. pixeltable/store.py +2 -2
  50. pixeltable/tool/create_test_db_dump.py +32 -13
  51. pixeltable/type_system.py +13 -54
  52. pixeltable/utils/documents.py +42 -12
  53. pixeltable/utils/http_server.py +70 -0
  54. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/METADATA +10 -7
  55. pixeltable-0.2.6.dist-info/RECORD +119 -0
  56. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  57. pixeltable/client.py +0 -600
  58. pixeltable/exprs/image_similarity_predicate.py +0 -58
  59. pixeltable/func/batched_function.py +0 -53
  60. pixeltable/tests/conftest.py +0 -171
  61. pixeltable/tests/ext/test_yolox.py +0 -21
  62. pixeltable/tests/functions/test_fireworks.py +0 -43
  63. pixeltable/tests/functions/test_functions.py +0 -60
  64. pixeltable/tests/functions/test_huggingface.py +0 -158
  65. pixeltable/tests/functions/test_openai.py +0 -162
  66. pixeltable/tests/functions/test_together.py +0 -112
  67. pixeltable/tests/test_audio.py +0 -65
  68. pixeltable/tests/test_catalog.py +0 -27
  69. pixeltable/tests/test_client.py +0 -21
  70. pixeltable/tests/test_component_view.py +0 -379
  71. pixeltable/tests/test_dataframe.py +0 -440
  72. pixeltable/tests/test_dirs.py +0 -107
  73. pixeltable/tests/test_document.py +0 -120
  74. pixeltable/tests/test_exprs.py +0 -802
  75. pixeltable/tests/test_function.py +0 -332
  76. pixeltable/tests/test_index.py +0 -138
  77. pixeltable/tests/test_migration.py +0 -44
  78. pixeltable/tests/test_nos.py +0 -54
  79. pixeltable/tests/test_snapshot.py +0 -231
  80. pixeltable/tests/test_table.py +0 -1343
  81. pixeltable/tests/test_transactional_directory.py +0 -42
  82. pixeltable/tests/test_types.py +0 -52
  83. pixeltable/tests/test_video.py +0 -159
  84. pixeltable/tests/test_view.py +0 -535
  85. pixeltable/tests/utils.py +0 -442
  86. pixeltable-0.2.5.dist-info/RECORD +0 -139
  87. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
pixeltable/func/udf.py CHANGED
@@ -6,7 +6,6 @@ from typing import List, Callable, Optional, overload, Any
6
6
  import pixeltable as pxt
7
7
  import pixeltable.exceptions as excs
8
8
  import pixeltable.type_system as ts
9
- from .batched_function import ExplicitBatchedFunction
10
9
  from .callable_function import CallableFunction
11
10
  from .expr_template_function import ExprTemplateFunction
12
11
  from .function import Function
@@ -62,8 +61,8 @@ def udf(*args, **kwargs):
62
61
 
63
62
  def decorator(decorated_fn: Callable):
64
63
  return make_function(
65
- decorated_fn, return_type, param_types, batch_size, substitute_fn=substitute_fn,
66
- force_stored=force_stored)
64
+ decorated_fn, return_type, param_types, batch_size,
65
+ substitute_fn=substitute_fn, force_stored=force_stored)
67
66
 
68
67
  return decorator
69
68
 
@@ -78,8 +77,8 @@ def make_function(
78
77
  force_stored: bool = False
79
78
  ) -> Function:
80
79
  """
81
- Constructs a `CallableFunction` or `BatchedFunction`, depending on the
82
- supplied parameters. If `substitute_fn` is specified, then `decorated_fn`
80
+ Constructs a `CallableFunction` from the specified parameters.
81
+ If `substitute_fn` is specified, then `decorated_fn`
83
82
  will be used only for its signature, with execution delegated to
84
83
  `substitute_fn`.
85
84
  """
@@ -117,11 +116,8 @@ def make_function(
117
116
  raise excs.Error(f'{errmsg_name}(): @udf decorator with a `substitute_fn` can only be used in a module')
118
117
  py_fn = substitute_fn
119
118
 
120
- if batch_size is None:
121
- result = CallableFunction(signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name)
122
- else:
123
- result = ExplicitBatchedFunction(
124
- signature=sig, batch_size=batch_size, invoker_fn=py_fn, self_path=function_path)
119
+ result = CallableFunction(
120
+ signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name, batch_size=batch_size)
125
121
 
126
122
  # If this function is part of a module, register it
127
123
  if function_path is not None:
pixeltable/functions/huggingface.py CHANGED
@@ -23,6 +23,16 @@ def sentence_transformer(
23
23
  return [array[i] for i in range(array.shape[0])]
24
24
 
25
25
 
26
+ @sentence_transformer.conditional_return_type
27
+ def _(model_id: str) -> ts.ArrayType:
28
+ try:
29
+ from sentence_transformers import SentenceTransformer
30
+ model = _lookup_model(model_id, SentenceTransformer)
31
+ return ts.ArrayType((model.get_sentence_embedding_dimension(),), dtype=ts.FloatType(), nullable=False)
32
+ except ImportError:
33
+ return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
34
+
35
+
26
36
  @pxt.udf
27
37
  def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
28
38
  env.Env.get().require_package('sentence_transformers')
@@ -56,7 +66,7 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
56
66
  return array.tolist()
57
67
 
58
68
 
59
- @pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
69
+ @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
60
70
  def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
61
71
  env.Env.get().require_package('transformers')
62
72
  device = resolve_torch_device('auto')
@@ -64,7 +74,6 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
64
74
  from transformers import CLIPModel, CLIPProcessor
65
75
 
66
76
  model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
67
- assert model.config.projection_dim == 512
68
77
  processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
69
78
 
70
79
  with torch.no_grad():
@@ -74,7 +83,7 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
74
83
  return [embeddings[i] for i in range(embeddings.shape[0])]
75
84
 
76
85
 
77
- @pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
86
+ @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
78
87
  def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndarray]:
79
88
  env.Env.get().require_package('transformers')
80
89
  device = resolve_torch_device('auto')
@@ -82,7 +91,6 @@ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndar
82
91
  from transformers import CLIPModel, CLIPProcessor
83
92
 
84
93
  model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
85
- assert model.config.projection_dim == 512
86
94
  processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
87
95
 
88
96
  with torch.no_grad():
@@ -92,6 +100,17 @@ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndar
92
100
  return [embeddings[i] for i in range(embeddings.shape[0])]
93
101
 
94
102
 
103
+ @clip_text.conditional_return_type
104
+ @clip_image.conditional_return_type
105
+ def _(model_id: str) -> ts.ArrayType:
106
+ try:
107
+ from transformers import CLIPModel
108
+ model = _lookup_model(model_id, CLIPModel.from_pretrained)
109
+ return ts.ArrayType((model.config.projection_dim,), dtype=ts.FloatType(), nullable=False)
110
+ except ImportError:
111
+ return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
112
+
113
+
95
114
  @pxt.udf(batch_size=4)
96
115
  def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
97
116
  env.Env.get().require_package('transformers')
pixeltable/functions/openai.py CHANGED
@@ -53,7 +53,7 @@ def speech(
53
53
  )
54
54
  ext = response_format or 'mp3'
55
55
  output_filename = str(env.Env.get().tmp_dir / f"{uuid.uuid4()}.{ext}")
56
- content.stream_to_file(output_filename, chunk_size=1 << 20)
56
+ content.write_to_file(output_filename)
57
57
  return output_filename
58
58
 
59
59
 
@@ -181,17 +181,26 @@ def vision(
181
181
  #####################################
182
182
  # Embeddings Endpoints
183
183
 
184
+ _embedding_dimensions_cache: dict[str, int] = {
185
+ 'text-embedding-ada-002': 1536,
186
+ 'text-embedding-3-small': 1536,
187
+ 'text-embedding-3-large': 3072,
188
+ }
189
+
190
+
184
191
  @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
185
192
  @_retry
186
193
  def embeddings(
187
194
  input: Batch[str],
188
195
  *,
189
196
  model: str,
197
+ dimensions: Optional[int] = None,
190
198
  user: Optional[str] = None
191
199
  ) -> Batch[np.ndarray]:
192
200
  result = openai_client().embeddings.create(
193
201
  input=input,
194
202
  model=model,
203
+ dimensions=_opt(dimensions),
195
204
  user=_opt(user),
196
205
  encoding_format='float'
197
206
  )
@@ -201,6 +210,16 @@ def embeddings(
201
210
  ]
202
211
 
203
212
 
213
+ @embeddings.conditional_return_type
214
+ def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
215
+ if dimensions is None:
216
+ if model not in _embedding_dimensions_cache:
217
+ # TODO: find some other way to retrieve a sample
218
+ return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
219
+ dimensions = _embedding_dimensions_cache.get(model, None)
220
+ return ts.ArrayType((dimensions,), dtype=ts.FloatType(), nullable=False)
221
+
222
+
204
223
  #####################################
205
224
  # Images Endpoints
206
225
 
@@ -232,6 +251,20 @@ def image_generations(
232
251
  return img
233
252
 
234
253
 
254
+ @image_generations.conditional_return_type
255
+ def _(size: Optional[str] = None) -> ts.ImageType:
256
+ if size is None:
257
+ return ts.ImageType(size=(1024, 1024))
258
+ x_pos = size.find('x')
259
+ if x_pos == -1:
260
+ return ts.ImageType()
261
+ try:
262
+ width, height = int(size[:x_pos]), int(size[x_pos + 1:])
263
+ except ValueError:
264
+ return ts.ImageType()
265
+ return ts.ImageType(size=(width, height))
266
+
267
+
235
268
  #####################################
236
269
  # Moderations Endpoints
237
270
 
pixeltable/functions/pil/image.py CHANGED
@@ -1,16 +1,12 @@
1
- from typing import Dict, Any, Tuple, Optional
1
+ from typing import Tuple, Optional
2
2
 
3
3
  import PIL.Image
4
+ from PIL.Image import Dither
4
5
 
5
- from pixeltable.type_system import FloatType, ImageType, IntType, ArrayType, ColumnType, StringType, JsonType, BoolType
6
6
  import pixeltable.func as func
7
+ from pixeltable.type_system import FloatType, ImageType, IntType, ArrayType, ColumnType, StringType, JsonType
7
8
 
8
9
 
9
- def _caller_return_type(bound_args: Optional[Dict[str, Any]]) -> ColumnType:
10
- if bound_args is None:
11
- return ImageType()
12
- return bound_args['self'].col_type
13
-
14
10
  @func.udf(
15
11
  py_fn=PIL.Image.alpha_composite, return_type=ImageType(), param_types=[ImageType(), ImageType()])
16
12
  def alpha_composite(im1: PIL.Image.Image, im2: PIL.Image.Image) -> PIL.Image.Image:
@@ -28,71 +24,78 @@ def composite(image1: PIL.Image.Image, image2: PIL.Image.Image, mask: PIL.Image.
28
24
  # PIL.Image.Image methods
29
25
 
30
26
  # Image.convert()
31
- def _convert_return_type(bound_args: Dict[str, Any]) -> ColumnType:
32
- if bound_args is None:
33
- return ImageType()
34
- assert 'self' in bound_args
35
- assert 'mode' in bound_args
36
- img_type = bound_args['self'].col_type
37
- return ImageType(size=img_type.size, mode=bound_args['mode'])
38
- @func.udf(return_type=_convert_return_type, param_types=[ImageType(), StringType()])
27
+ @func.udf(param_types=[ImageType(), StringType()])
39
28
  def convert(self: PIL.Image.Image, mode: str) -> PIL.Image.Image:
40
29
  return self.convert(mode)
41
30
 
31
+
32
+ @convert.conditional_return_type
33
+ def _(self: PIL.Image.Image, mode: str) -> ColumnType:
34
+ input_type = self.col_type
35
+ assert input_type.is_image_type()
36
+ return ImageType(size=input_type.size, mode=mode, nullable=input_type.nullable)
37
+
38
+
42
39
  # Image.crop()
43
- def _crop_return_type(bound_args: Dict[str, Any]) -> ColumnType:
44
- if bound_args is None:
45
- return ImageType()
46
- img_type = bound_args['self'].col_type
47
- box = bound_args['box']
48
- if isinstance(box, list) and all(isinstance(x, int) for x in box):
49
- return ImageType(size=(box[2] - box[0], box[3] - box[1]), mode=img_type.mode)
50
- return ImageType() # we can't compute the size statically
51
40
  @func.udf(
52
- py_fn=PIL.Image.Image.crop, return_type=_crop_return_type,
41
+ py_fn=PIL.Image.Image.crop,
53
42
  param_types=[ImageType(), ArrayType((4,), dtype=IntType())])
54
43
  def crop(self: PIL.Image.Image, box: Tuple[int, int, int, int]) -> PIL.Image.Image:
55
44
  pass
56
45
 
46
+ @crop.conditional_return_type
47
+ def _(self: PIL.Image.Image, box: Tuple[int, int, int, int]) -> ColumnType:
48
+ input_type = self.col_type
49
+ assert input_type.is_image_type()
50
+ if isinstance(box, list) and all(isinstance(x, int) for x in box):
51
+ return ImageType(size=(box[2] - box[0], box[3] - box[1]), mode=input_type.mode, nullable=input_type.nullable)
52
+ return ImageType(mode=input_type.mode, nullable=input_type.nullable) # we can't compute the size statically
53
+
57
54
  # Image.getchannel()
58
- def _getchannel_return_type(bound_args: Dict[str, Any]) -> ColumnType:
59
- if bound_args is None:
60
- return ImageType()
61
- img_type = bound_args['self'].col_type
62
- return ImageType(size=img_type.size, mode='L')
63
- @func.udf(
64
- py_fn=PIL.Image.Image.getchannel, return_type=_getchannel_return_type, param_types=[ImageType(), IntType()])
55
+ @func.udf(py_fn=PIL.Image.Image.getchannel, param_types=[ImageType(), IntType()])
65
56
  def getchannel(self: PIL.Image.Image, channel: int) -> PIL.Image.Image:
66
57
  pass
67
58
 
59
+ @getchannel.conditional_return_type
60
+ def _(self: PIL.Image.Image) -> ColumnType:
61
+ input_type = self.col_type
62
+ assert input_type.is_image_type()
63
+ return ImageType(size=input_type.size, mode='L', nullable=input_type.nullable)
64
+
65
+
68
66
  # Image.resize()
69
- def resize_return_type(bound_args: Dict[str, Any]) -> ColumnType:
70
- if bound_args is None:
71
- return ImageType()
72
- assert 'size' in bound_args
73
- return ImageType(size=bound_args['size'])
74
- @func.udf(return_type=resize_return_type, param_types=[ImageType(), ArrayType((2, ), dtype=IntType())])
67
+ @func.udf(param_types=[ImageType(), ArrayType((2, ), dtype=IntType())])
75
68
  def resize(self: PIL.Image.Image, size: Tuple[int, int]) -> PIL.Image.Image:
76
69
  return self.resize(size)
77
70
 
71
+ @resize.conditional_return_type
72
+ def _(self: PIL.Image.Image, size: Tuple[int, int]) -> ColumnType:
73
+ input_type = self.col_type
74
+ assert input_type.is_image_type()
75
+ return ImageType(size=size, mode=input_type.mode, nullable=input_type.nullable)
76
+
78
77
  # Image.rotate()
79
- @func.udf(return_type=ImageType(), param_types=[ImageType(), IntType()])
78
+ @func.udf(param_types=[ImageType(), IntType()])
80
79
  def rotate(self: PIL.Image.Image, angle: int) -> PIL.Image.Image:
81
80
  return self.rotate(angle)
82
81
 
83
- # Image.transform()
84
- @func.udf(return_type= _caller_return_type, param_types=[ImageType(), ArrayType((2,), dtype=IntType()), IntType()])
85
- def transform(self: PIL.Image.Image, size: Tuple[int, int], method: int) -> PIL.Image.Image:
86
- return self.transform(size, method)
82
+ @func.udf(py_fn=PIL.Image.Image.effect_spread, param_types=[ImageType(), IntType()])
83
+ def effect_spread(self: PIL.Image.Image, distance: int) -> PIL.Image.Image:
84
+ pass
87
85
 
88
- @func.udf(
89
- py_fn=PIL.Image.Image.effect_spread, return_type=_caller_return_type, param_types=[ImageType(), FloatType()])
90
- def effect_spread(self: PIL.Image.Image, distance: float) -> PIL.Image.Image:
86
+ @func.udf(py_fn=PIL.Image.Image.transpose, param_types=[ImageType(), IntType()])
87
+ def transpose(self: PIL.Image.Image, method: int) -> PIL.Image.Image:
91
88
  pass
92
89
 
90
+ @rotate.conditional_return_type
91
+ @effect_spread.conditional_return_type
92
+ @transpose.conditional_return_type
93
+ def _(self: PIL.Image.Image) -> ColumnType:
94
+ return self.col_type
95
+
93
96
  @func.udf(
94
97
  py_fn=PIL.Image.Image.entropy, return_type=FloatType(), param_types=[ImageType(), ImageType(), JsonType()])
95
- def entropy(self: PIL.Image.Image, mask: PIL.Image.Image, histogram: Dict) -> float:
98
+ def entropy(self: PIL.Image.Image, mask: PIL.Image.Image, extrema: Optional[list] = None) -> float:
96
99
  pass
97
100
 
98
101
  @func.udf(py_fn=PIL.Image.Image.getbands, return_type=JsonType(), param_types=[ImageType()])
@@ -103,8 +106,7 @@ def getbands(self: PIL.Image.Image) -> Tuple[str]:
103
106
  def getbbox(self: PIL.Image.Image) -> Tuple[int, int, int, int]:
104
107
  pass
105
108
 
106
- @func.udf(
107
- py_fn=PIL.Image.Image.getcolors, return_type=JsonType(), param_types=[ImageType(), IntType()])
109
+ @func.udf(py_fn=PIL.Image.Image.getcolors, return_type=JsonType(), param_types=[ImageType(), IntType()])
108
110
  def getcolors(self: PIL.Image.Image, maxcolors: int) -> Tuple[Tuple[int, int, int], int]:
109
111
  pass
110
112
 
@@ -114,37 +116,32 @@ def getextrema(self: PIL.Image.Image) -> Tuple[int, int]:
114
116
 
115
117
  @func.udf(
116
118
  py_fn=PIL.Image.Image.getpalette, return_type=JsonType(), param_types=[ImageType(), StringType()])
117
- def getpalette(self: PIL.Image.Image, mode: str) -> Tuple[int]:
119
+ def getpalette(self: PIL.Image.Image, mode: Optional[str] = None) -> Tuple[int]:
118
120
  pass
119
121
 
120
122
  @func.udf(
121
- py_fn=PIL.Image.Image.getpixel, return_type=JsonType(), param_types=[ImageType(), ArrayType((2,), dtype=IntType())])
122
- def getpixel(self: PIL.Image.Image, xy: Tuple[int, int]) -> Tuple[int]:
123
- pass
123
+ return_type=JsonType(), param_types=[ImageType(), ArrayType((2,), dtype=IntType())])
124
+ def getpixel(self: PIL.Image.Image, xy: tuple[int, int]) -> Tuple[int]:
125
+ # `xy` will be a list; `tuple(xy)` is necessary for pillow 9 compatibility
126
+ return self.getpixel(tuple(xy))
124
127
 
125
- @func.udf(
126
- py_fn=PIL.Image.Image.getprojection, return_type=JsonType(), param_types=[ImageType()])
128
+ @func.udf(py_fn=PIL.Image.Image.getprojection, return_type=JsonType(), param_types=[ImageType()])
127
129
  def getprojection(self: PIL.Image.Image) -> Tuple[int]:
128
130
  pass
129
131
 
130
- @func.udf(
131
- py_fn=PIL.Image.Image.histogram, return_type=JsonType(), param_types=[ImageType(), ImageType(), JsonType()])
132
- def histogram(self: PIL.Image.Image, mask: PIL.Image.Image, histogram: Dict) -> Tuple[int]:
132
+ @func.udf(py_fn=PIL.Image.Image.histogram, return_type=JsonType(), param_types=[ImageType(), ImageType(), JsonType()])
133
+ def histogram(self: PIL.Image.Image, mask: PIL.Image.Image, extrema: Optional[list] = None) -> Tuple[int]:
133
134
  pass
134
135
 
135
136
  @func.udf(
136
137
  py_fn=PIL.Image.Image.quantize, return_type=ImageType(),
137
138
  param_types=[ImageType(), IntType(), IntType(nullable=True), IntType(), IntType(nullable=True), IntType()])
138
139
  def quantize(
139
- self: PIL.Image.Image, colors: int, method: int, kmeans: int, palette: int, dither: int) -> PIL.Image.Image:
140
+ self: PIL.Image.Image, colors: int = 256, method: Optional[int] = None, kmeans: int = 0,
141
+ palette: Optional[int] = None, dither: int = Dither.FLOYDSTEINBERG) -> PIL.Image.Image:
140
142
  pass
141
143
 
142
144
  @func.udf(
143
145
  py_fn=PIL.Image.Image.reduce, return_type=ImageType(), param_types=[ImageType(), IntType(), JsonType()])
144
- def reduce(self: PIL.Image.Image, factor: int, filter: Tuple[int]) -> PIL.Image.Image:
145
- pass
146
-
147
- @func.udf(
148
- py_fn=PIL.Image.Image.transpose, return_type=_caller_return_type, param_types=[ImageType(), IntType()])
149
- def transpose(self: PIL.Image.Image, method: int) -> PIL.Image.Image:
146
+ def reduce(self: PIL.Image.Image, factor: int, box: Optional[Tuple[int]]) -> PIL.Image.Image:
150
147
  pass
pixeltable/functions/together.py CHANGED
@@ -85,6 +85,18 @@ def chat_completions(
85
85
  ).dict()
86
86
 
87
87
 
88
+ _embedding_dimensions_cache = {
89
+ 'togethercomputer/m2-bert-80M-2k-retrieval': 768,
90
+ 'togethercomputer/m2-bert-80M-8k-retrieval': 768,
91
+ 'togethercomputer/m2-bert-80M-32k-retrieval': 768,
92
+ 'WhereIsAI/UAE-Large-V1': 1024,
93
+ 'BAAI/bge-large-en-v1.5': 1024,
94
+ 'BAAI/bge-base-en-v1.5': 768,
95
+ 'sentence-transformers/msmarco-bert-base-dot-v5': 768,
96
+ 'bert-base-uncased': 768,
97
+ }
98
+
99
+
88
100
  @pxt.udf(batch_size=32, return_type=pxt.ArrayType((None,), dtype=pxt.FloatType()))
89
101
  def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
90
102
  result = together_client().embeddings.create(input=input, model=model)
@@ -94,6 +106,15 @@ def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
94
106
  ]
95
107
 
96
108
 
109
+ @embeddings.conditional_return_type
110
+ def _(model: str) -> pxt.ArrayType:
111
+ if model not in _embedding_dimensions_cache:
112
+ # TODO: find some other way to retrieve a sample
113
+ return pxt.ArrayType((None,), dtype=pxt.FloatType())
114
+ dimensions = _embedding_dimensions_cache[model]
115
+ return pxt.ArrayType((dimensions,), dtype=pxt.FloatType())
116
+
117
+
97
118
  @pxt.udf
98
119
  def image_generations(
99
120
  prompt: str,