pixeltable 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.
Files changed (99)
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +31 -50
  4. pixeltable/catalog/insertable_table.py +7 -6
  5. pixeltable/catalog/table.py +171 -57
  6. pixeltable/catalog/table_version.py +417 -140
  7. pixeltable/catalog/table_version_path.py +2 -2
  8. pixeltable/dataframe.py +239 -121
  9. pixeltable/env.py +82 -16
  10. pixeltable/exec/__init__.py +2 -1
  11. pixeltable/exec/cache_prefetch_node.py +1 -1
  12. pixeltable/exec/data_row_batch.py +6 -7
  13. pixeltable/exec/expr_eval_node.py +28 -28
  14. pixeltable/exec/in_memory_data_node.py +11 -7
  15. pixeltable/exec/sql_scan_node.py +7 -6
  16. pixeltable/exprs/__init__.py +4 -3
  17. pixeltable/exprs/column_ref.py +9 -0
  18. pixeltable/exprs/comparison.py +3 -3
  19. pixeltable/exprs/data_row.py +5 -1
  20. pixeltable/exprs/expr.py +15 -7
  21. pixeltable/exprs/function_call.py +17 -15
  22. pixeltable/exprs/image_member_access.py +9 -28
  23. pixeltable/exprs/in_predicate.py +96 -0
  24. pixeltable/exprs/inline_array.py +13 -11
  25. pixeltable/exprs/inline_dict.py +15 -13
  26. pixeltable/exprs/literal.py +16 -4
  27. pixeltable/exprs/row_builder.py +15 -41
  28. pixeltable/exprs/similarity_expr.py +65 -0
  29. pixeltable/ext/__init__.py +5 -0
  30. pixeltable/ext/functions/yolox.py +92 -0
  31. pixeltable/func/__init__.py +0 -2
  32. pixeltable/func/aggregate_function.py +18 -15
  33. pixeltable/func/callable_function.py +57 -13
  34. pixeltable/func/expr_template_function.py +20 -3
  35. pixeltable/func/function.py +35 -4
  36. pixeltable/func/globals.py +24 -14
  37. pixeltable/func/signature.py +23 -27
  38. pixeltable/func/udf.py +13 -12
  39. pixeltable/functions/__init__.py +8 -8
  40. pixeltable/functions/eval.py +7 -8
  41. pixeltable/functions/huggingface.py +64 -17
  42. pixeltable/functions/openai.py +36 -3
  43. pixeltable/functions/pil/image.py +61 -64
  44. pixeltable/functions/together.py +21 -0
  45. pixeltable/functions/util.py +11 -0
  46. pixeltable/globals.py +425 -0
  47. pixeltable/index/__init__.py +2 -0
  48. pixeltable/index/base.py +51 -0
  49. pixeltable/index/embedding_index.py +168 -0
  50. pixeltable/io/__init__.py +3 -0
  51. pixeltable/{utils → io}/hf_datasets.py +48 -17
  52. pixeltable/io/pandas.py +148 -0
  53. pixeltable/{utils → io}/parquet.py +58 -33
  54. pixeltable/iterators/__init__.py +1 -1
  55. pixeltable/iterators/base.py +4 -0
  56. pixeltable/iterators/document.py +218 -97
  57. pixeltable/iterators/video.py +8 -9
  58. pixeltable/metadata/__init__.py +7 -3
  59. pixeltable/metadata/converters/convert_12.py +3 -0
  60. pixeltable/metadata/converters/convert_13.py +41 -0
  61. pixeltable/metadata/schema.py +45 -22
  62. pixeltable/plan.py +15 -51
  63. pixeltable/store.py +38 -41
  64. pixeltable/tool/create_test_db_dump.py +39 -4
  65. pixeltable/type_system.py +47 -96
  66. pixeltable/utils/documents.py +42 -12
  67. pixeltable/utils/http_server.py +70 -0
  68. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/METADATA +14 -10
  69. pixeltable-0.2.6.dist-info/RECORD +119 -0
  70. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  71. pixeltable/client.py +0 -604
  72. pixeltable/exprs/image_similarity_predicate.py +0 -58
  73. pixeltable/func/batched_function.py +0 -53
  74. pixeltable/tests/conftest.py +0 -177
  75. pixeltable/tests/functions/test_fireworks.py +0 -42
  76. pixeltable/tests/functions/test_functions.py +0 -60
  77. pixeltable/tests/functions/test_huggingface.py +0 -158
  78. pixeltable/tests/functions/test_openai.py +0 -152
  79. pixeltable/tests/functions/test_together.py +0 -111
  80. pixeltable/tests/test_audio.py +0 -65
  81. pixeltable/tests/test_catalog.py +0 -27
  82. pixeltable/tests/test_client.py +0 -21
  83. pixeltable/tests/test_component_view.py +0 -370
  84. pixeltable/tests/test_dataframe.py +0 -439
  85. pixeltable/tests/test_dirs.py +0 -107
  86. pixeltable/tests/test_document.py +0 -120
  87. pixeltable/tests/test_exprs.py +0 -805
  88. pixeltable/tests/test_function.py +0 -324
  89. pixeltable/tests/test_migration.py +0 -43
  90. pixeltable/tests/test_nos.py +0 -54
  91. pixeltable/tests/test_snapshot.py +0 -208
  92. pixeltable/tests/test_table.py +0 -1267
  93. pixeltable/tests/test_transactional_directory.py +0 -42
  94. pixeltable/tests/test_types.py +0 -22
  95. pixeltable/tests/test_video.py +0 -159
  96. pixeltable/tests/test_view.py +0 -530
  97. pixeltable/tests/utils.py +0 -408
  98. pixeltable-0.2.4.dist-info/RECORD +0 -132
  99. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
pixeltable/functions/__init__.py:

@@ -23,8 +23,8 @@ def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
     return expr
 
 @func.uda(
-    update_types=[IntType()], value_type=IntType(), name='sum', allows_window=True, requires_order_by=False)
-class SumAggregator(func.Aggregator):
+    update_types=[IntType()], value_type=IntType(), allows_window=True, requires_order_by=False)
+class sum(func.Aggregator):
     def __init__(self):
         self.sum: Union[int, float] = 0
     def update(self, val: Union[int, float]) -> None:
@@ -35,8 +35,8 @@ class SumAggregator(func.Aggregator):
 
 
 @func.uda(
-    update_types=[IntType()], value_type=IntType(), name='count', allows_window = True, requires_order_by = False)
-class CountAggregator(func.Aggregator):
+    update_types=[IntType()], value_type=IntType(), allows_window = True, requires_order_by = False)
+class count(func.Aggregator):
     def __init__(self):
         self.count = 0
     def update(self, val: int) -> None:
@@ -47,8 +47,8 @@ class CountAggregator(func.Aggregator):
 
 
 @func.uda(
-    update_types=[IntType()], value_type=FloatType(), name='mean', allows_window=False, requires_order_by=False)
-class MeanAggregator(func.Aggregator):
+    update_types=[IntType()], value_type=FloatType(), allows_window=False, requires_order_by=False)
+class mean(func.Aggregator):
     def __init__(self):
         self.sum = 0
         self.count = 0
@@ -63,9 +63,9 @@ class MeanAggregator(func.Aggregator):
 
 
 @func.uda(
-    init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(), name='make_video',
+    init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(),
     requires_order_by=True, allows_window=False)
-class VideoAggregator(func.Aggregator):
+class make_video(func.Aggregator):
     def __init__(self, fps: int = 25):
         """follows https://pyav.org/docs/develop/cookbook/numpy.html#generating-video"""
         self.container: Optional[av.container.OutputContainer] = None
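
With the `name=` argument gone, the lowercase class name itself is what `@func.uda` registers, so the aggregate is called under that name. A minimal usage sketch, assuming the module-level `pxt.get_table` API that replaces the removed `client.py` (table and column names are hypothetical):

    import pixeltable as pxt
    from pixeltable.functions import sum as pxt_sum  # shadows builtins.sum, hence the alias

    t = pxt.get_table('sales')             # hypothetical table with an int column 'amount'
    t.select(pxt_sum(t.amount)).collect()  # aggregate over all rows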
pixeltable/functions/eval.py:

@@ -1,4 +1,3 @@
-from __future__ import annotations
 from typing import List, Tuple, Dict
 from collections import defaultdict
 import sys
@@ -157,16 +156,16 @@ def calculate_image_tpfp(
     ts.JsonType(nullable=False)
 ])
 def eval_detections(
-    pred_bboxes: List[List[int]], pred_classes: List[int], pred_scores: List[float],
-    gt_bboxes: List[List[int]], gt_classes: List[int]
+    pred_bboxes: List[List[int]], pred_labels: List[int], pred_scores: List[float],
+    gt_bboxes: List[List[int]], gt_labels: List[int]
 ) -> Dict:
-    class_idxs = list(set(pred_classes + gt_classes))
+    class_idxs = list(set(pred_labels + gt_labels))
     result: List[Dict] = []
     pred_bboxes_arr = np.asarray(pred_bboxes)
-    pred_classes_arr = np.asarray(pred_classes)
+    pred_classes_arr = np.asarray(pred_labels)
     pred_scores_arr = np.asarray(pred_scores)
     gt_bboxes_arr = np.asarray(gt_bboxes)
-    gt_classes_arr = np.asarray(gt_classes)
+    gt_classes_arr = np.asarray(gt_labels)
     for class_idx in class_idxs:
         pred_filter = pred_classes_arr == class_idx
         gt_filter = gt_classes_arr == class_idx
@@ -181,8 +180,8 @@ def eval_detections(
     return result
 
 @func.uda(
-    update_types=[ts.JsonType()], value_type=ts.JsonType(), name='mean_ap', allows_std_agg=True, allows_window=False)
-class MeanAPAggregator:
+    update_types=[ts.JsonType()], value_type=ts.JsonType(), allows_std_agg=True, allows_window=False)
+class mean_ap(func.Aggregator):
     def __init__(self):
         self.class_tpfp: Dict[int, List[Dict]] = defaultdict(list)
 
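The keyword renames (`pred_classes` → `pred_labels`, `gt_classes` → `gt_labels`) are breaking for callers that pass these arguments by name. A sketch of the detection-evaluation flow under the new names, given a table handle `t` with hypothetical per-image prediction and ground-truth columns:

    from pixeltable.functions.eval import eval_detections, mean_ap

    # per-row true/false-positive stats, using the renamed label arguments
    t.add_column(tpfp=eval_detections(
        t.pred_bboxes, t.pred_labels, t.pred_scores, t.gt_bboxes, t.gt_labels))
    # aggregate mean average precision over the table
    t.select(mean_ap(t.tpfp)).collect()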
pixeltable/functions/huggingface.py:

@@ -1,4 +1,4 @@
-from typing import Any, Callable
+from typing import Callable, TypeVar, Optional
 
 import PIL.Image
 import numpy as np
@@ -7,10 +7,13 @@ import pixeltable as pxt
 import pixeltable.env as env
 import pixeltable.type_system as ts
 from pixeltable.func import Batch
+from pixeltable.functions.util import resolve_torch_device
 
 
 @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
-def sentence_transformer(sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False) -> Batch[np.ndarray]:
+def sentence_transformer(
+    sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False
+) -> Batch[np.ndarray]:
     env.Env.get().require_package('sentence_transformers')
     from sentence_transformers import SentenceTransformer
 
@@ -20,6 +23,16 @@ def sentence_transformer(sentences: Batch[str], *, model_id: str, normalize_embe
     return [array[i] for i in range(array.shape[0])]
 
 
+@sentence_transformer.conditional_return_type
+def _(model_id: str) -> ts.ArrayType:
+    try:
+        from sentence_transformers import SentenceTransformer
+        model = _lookup_model(model_id, SentenceTransformer)
+        return ts.ArrayType((model.get_sentence_embedding_dimension(),), dtype=ts.FloatType(), nullable=False)
+    except ImportError:
+        return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
+
+
 @pxt.udf
 def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
     env.Env.get().require_package('sentence_transformers')
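
The new `conditional_return_type` hook lets the planner compute a concrete embedding shape when `model_id` is known at schema time, instead of the generic `(None,)` float array. A sketch, assuming a hypothetical table `docs` with a string column `text`:

    import pixeltable as pxt
    from pixeltable.functions.huggingface import sentence_transformer

    t = pxt.get_table('docs')
    # 'all-MiniLM-L6-v2' produces 384-dim embeddings, so the computed column's
    # type resolves to a (384,)-shaped float array rather than (None,)
    t.add_column(embedding=sentence_transformer(t.text, model_id='all-MiniLM-L6-v2'))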
@@ -56,41 +69,66 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
 @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
 def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
     env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from transformers import CLIPModel, CLIPProcessor
 
-    model = _lookup_model(model_id, CLIPModel.from_pretrained)
+    model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
     processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
 
-    inputs = processor(text=text, return_tensors='pt', padding=True, truncation=True)
-    embeddings = model.get_text_features(**inputs).detach().numpy()
+    with torch.no_grad():
+        inputs = processor(text=text, return_tensors='pt', padding=True, truncation=True)
+        embeddings = model.get_text_features(**inputs.to(device)).detach().to('cpu').numpy()
+
     return [embeddings[i] for i in range(embeddings.shape[0])]
 
 
 @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
 def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndarray]:
     env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from transformers import CLIPModel, CLIPProcessor
 
-    model = _lookup_model(model_id, CLIPModel.from_pretrained)
+    model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
     processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
 
-    inputs = processor(images=image, return_tensors='pt', padding=True)
-    embeddings = model.get_image_features(**inputs).detach().numpy()
+    with torch.no_grad():
+        inputs = processor(images=image, return_tensors='pt', padding=True)
+        embeddings = model.get_image_features(**inputs.to(device)).detach().to('cpu').numpy()
+
     return [embeddings[i] for i in range(embeddings.shape[0])]
 
 
-@pxt.udf(batch_size=32)
+@clip_text.conditional_return_type
+@clip_image.conditional_return_type
+def _(model_id: str) -> ts.ArrayType:
+    try:
+        from transformers import CLIPModel
+        model = _lookup_model(model_id, CLIPModel.from_pretrained)
+        return ts.ArrayType((model.config.projection_dim,), dtype=ts.FloatType(), nullable=False)
+    except ImportError:
+        return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
+
+
+@pxt.udf(batch_size=4)
 def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
     env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from transformers import DetrImageProcessor, DetrForObjectDetection
 
-    model = _lookup_model(model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'))
+    model = _lookup_model(
+        model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'), device=device)
     processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision='no_timm'))
 
-    inputs = processor(images=image, return_tensors='pt')
-    outputs = model(**inputs)
+    with torch.no_grad():
+        inputs = processor(images=image, return_tensors='pt')
+        outputs = model(**inputs.to(device))
+        results = processor.post_process_object_detection(
+            outputs, threshold=threshold, target_sizes=[(img.height, img.width) for img in image]
+        )
 
-    results = processor.post_process_object_detection(outputs, threshold=threshold)
     return [
         {
             'scores': [score.item() for score in result['scores']],
@@ -102,14 +140,23 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
     ]
 
 
-def _lookup_model(model_id: str, create: Callable) -> Any:
-    key = (model_id, create)  # For safety, include the `create` callable in the cache key
+T = TypeVar('T')
+
+
+def _lookup_model(model_id: str, create: Callable[[str], T], device: Optional[str] = None) -> T:
+    from torch import nn
+    key = (model_id, create, device)  # For safety, include the `create` callable in the cache key
     if key not in _model_cache:
-        _model_cache[key] = create(model_id)
+        model = create(model_id)
+        if device is not None:
+            model.to(device)
+        if isinstance(model, nn.Module):
+            model.eval()
+        _model_cache[key] = model
     return _model_cache[key]
 
 
-def _lookup_processor(model_id: str, create: Callable) -> Any:
+def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
     key = (model_id, create)  # For safety, include the `create` callable in the cache key
     if key not in _processor_cache:
         _processor_cache[key] = create(model_id)
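
The CLIP and DETR UDFs now share the device-aware `_lookup_model` cache: a model is created once per `(model_id, create, device)` key, moved to the resolved device, and switched to eval mode, while inference runs under `torch.no_grad()` and results are copied back to the CPU. Nothing changes at the call site; a hedged sketch, given a table handle `t` with a hypothetical image column `img`:

    from pixeltable.functions.huggingface import clip_image

    # runs on CUDA or Apple-silicon MPS when available; the conditional return
    # type resolves to a (512,)-dim array for this model's projection head
    t.add_column(img_embed=clip_image(t.img, model_id='openai/clip-vit-base-patch32'))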
pixeltable/functions/openai.py:

@@ -26,8 +26,8 @@ def openai_client() -> openai.OpenAI:
 def _retry(fn: Callable) -> Callable:
     return tenacity.retry(
         retry=tenacity.retry_if_exception_type(openai.RateLimitError),
-        wait=tenacity.wait_random_exponential(min=1, max=60),
-        stop=tenacity.stop_after_attempt(6)
+        wait=tenacity.wait_random_exponential(multiplier=3, max=180),
+        stop=tenacity.stop_after_attempt(20)
     )(fn)
 
 
@@ -53,7 +53,7 @@ def speech(
     )
     ext = response_format or 'mp3'
     output_filename = str(env.Env.get().tmp_dir / f"{uuid.uuid4()}.{ext}")
-    content.stream_to_file(output_filename, chunk_size=1 << 20)
+    content.write_to_file(output_filename)
     return output_filename
 
 
@@ -181,17 +181,26 @@ def vision(
 #####################################
 # Embeddings Endpoints
 
+_embedding_dimensions_cache: dict[str, int] = {
+    'text-embedding-ada-002': 1536,
+    'text-embedding-3-small': 1536,
+    'text-embedding-3-large': 3072,
+}
+
+
 @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
 @_retry
 def embeddings(
     input: Batch[str],
     *,
     model: str,
+    dimensions: Optional[int] = None,
     user: Optional[str] = None
 ) -> Batch[np.ndarray]:
     result = openai_client().embeddings.create(
         input=input,
         model=model,
+        dimensions=_opt(dimensions),
         user=_opt(user),
         encoding_format='float'
     )
@@ -201,6 +210,16 @@ def embeddings(
     ]
 
 
+@embeddings.conditional_return_type
+def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
+    if dimensions is None:
+        if model not in _embedding_dimensions_cache:
+            # TODO: find some other way to retrieve a sample
+            return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
+        dimensions = _embedding_dimensions_cache.get(model, None)
+    return ts.ArrayType((dimensions,), dtype=ts.FloatType(), nullable=False)
+
+
 #####################################
 # Images Endpoints
 
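When `dimensions` is passed explicitly it is forwarded to the API and used directly as the static array shape; otherwise the shape comes from the small cache of known models, falling back to `(None,)` for unknown ones. A sketch, given a table handle `t` with a hypothetical string column `text`:

    from pixeltable.functions.openai import embeddings

    # text-embedding-3 models accept a dimensions override, so the column
    # type resolves to a (512,)-shaped float array
    t.add_column(embed=embeddings(t.text, model='text-embedding-3-small', dimensions=512))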
@@ -232,6 +251,20 @@ def image_generations(
     return img
 
 
+@image_generations.conditional_return_type
+def _(size: Optional[str] = None) -> ts.ImageType:
+    if size is None:
+        return ts.ImageType(size=(1024, 1024))
+    x_pos = size.find('x')
+    if x_pos == -1:
+        return ts.ImageType()
+    try:
+        width, height = int(size[:x_pos]), int(size[x_pos + 1:])
+    except ValueError:
+        return ts.ImageType()
+    return ts.ImageType(size=(width, height))
+
+
 #####################################
 # Moderations Endpoints
 
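The return type now reflects the requested size: a well-formed '<width>x<height>' string becomes a statically sized image type, a missing `size` is assumed to be the 1024x1024 default, and anything unparseable degrades to an unsized image. For instance, given a hypothetical prompt column:

    from pixeltable.functions.openai import image_generations

    t.add_column(img=image_generations(t.prompt, size='512x512'))  # type: Image[(512, 512)]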
pixeltable/functions/pil/image.py:

@@ -1,16 +1,12 @@
-from typing import Dict, Any, Tuple, Optional
+from typing import Tuple, Optional
 
 import PIL.Image
+from PIL.Image import Dither
 
-from pixeltable.type_system import FloatType, ImageType, IntType, ArrayType, ColumnType, StringType, JsonType, BoolType
 import pixeltable.func as func
+from pixeltable.type_system import FloatType, ImageType, IntType, ArrayType, ColumnType, StringType, JsonType
 
 
-def _caller_return_type(bound_args: Optional[Dict[str, Any]]) -> ColumnType:
-    if bound_args is None:
-        return ImageType()
-    return bound_args['self'].col_type
-
 @func.udf(
     py_fn=PIL.Image.alpha_composite, return_type=ImageType(), param_types=[ImageType(), ImageType()])
 def alpha_composite(im1: PIL.Image.Image, im2: PIL.Image.Image) -> PIL.Image.Image:
@@ -28,71 +24,78 @@ def composite(image1: PIL.Image.Image, image2: PIL.Image.Image, mask: PIL.Image.
 # PIL.Image.Image methods
 
 # Image.convert()
-def _convert_return_type(bound_args: Dict[str, Any]) -> ColumnType:
-    if bound_args is None:
-        return ImageType()
-    assert 'self' in bound_args
-    assert 'mode' in bound_args
-    img_type = bound_args['self'].col_type
-    return ImageType(size=img_type.size, mode=bound_args['mode'])
-@func.udf(return_type=_convert_return_type, param_types=[ImageType(), StringType()])
+@func.udf(param_types=[ImageType(), StringType()])
 def convert(self: PIL.Image.Image, mode: str) -> PIL.Image.Image:
     return self.convert(mode)
 
+
+@convert.conditional_return_type
+def _(self: PIL.Image.Image, mode: str) -> ColumnType:
+    input_type = self.col_type
+    assert input_type.is_image_type()
+    return ImageType(size=input_type.size, mode=mode, nullable=input_type.nullable)
+
+
 # Image.crop()
-def _crop_return_type(bound_args: Dict[str, Any]) -> ColumnType:
-    if bound_args is None:
-        return ImageType()
-    img_type = bound_args['self'].col_type
-    box = bound_args['box']
-    if isinstance(box, list) and all(isinstance(x, int) for x in box):
-        return ImageType(size=(box[2] - box[0], box[3] - box[1]), mode=img_type.mode)
-    return ImageType()  # we can't compute the size statically
 @func.udf(
-    py_fn=PIL.Image.Image.crop, return_type=_crop_return_type,
+    py_fn=PIL.Image.Image.crop,
     param_types=[ImageType(), ArrayType((4,), dtype=IntType())])
 def crop(self: PIL.Image.Image, box: Tuple[int, int, int, int]) -> PIL.Image.Image:
     pass
 
+@crop.conditional_return_type
+def _(self: PIL.Image.Image, box: Tuple[int, int, int, int]) -> ColumnType:
+    input_type = self.col_type
+    assert input_type.is_image_type()
+    if isinstance(box, list) and all(isinstance(x, int) for x in box):
+        return ImageType(size=(box[2] - box[0], box[3] - box[1]), mode=input_type.mode, nullable=input_type.nullable)
+    return ImageType(mode=input_type.mode, nullable=input_type.nullable)  # we can't compute the size statically
+
 # Image.getchannel()
-def _getchannel_return_type(bound_args: Dict[str, Any]) -> ColumnType:
-    if bound_args is None:
-        return ImageType()
-    img_type = bound_args['self'].col_type
-    return ImageType(size=img_type.size, mode='L')
-@func.udf(
-    py_fn=PIL.Image.Image.getchannel, return_type=_getchannel_return_type, param_types=[ImageType(), IntType()])
+@func.udf(py_fn=PIL.Image.Image.getchannel, param_types=[ImageType(), IntType()])
 def getchannel(self: PIL.Image.Image, channel: int) -> PIL.Image.Image:
     pass
 
+@getchannel.conditional_return_type
+def _(self: PIL.Image.Image) -> ColumnType:
+    input_type = self.col_type
+    assert input_type.is_image_type()
+    return ImageType(size=input_type.size, mode='L', nullable=input_type.nullable)
+
+
 # Image.resize()
-def resize_return_type(bound_args: Dict[str, Any]) -> ColumnType:
-    if bound_args is None:
-        return ImageType()
-    assert 'size' in bound_args
-    return ImageType(size=bound_args['size'])
-@func.udf(return_type=resize_return_type, param_types=[ImageType(), ArrayType((2, ), dtype=IntType())])
+@func.udf(param_types=[ImageType(), ArrayType((2, ), dtype=IntType())])
 def resize(self: PIL.Image.Image, size: Tuple[int, int]) -> PIL.Image.Image:
     return self.resize(size)
 
+@resize.conditional_return_type
+def _(self: PIL.Image.Image, size: Tuple[int, int]) -> ColumnType:
+    input_type = self.col_type
+    assert input_type.is_image_type()
+    return ImageType(size=size, mode=input_type.mode, nullable=input_type.nullable)
+
 # Image.rotate()
-@func.udf(return_type=ImageType(), param_types=[ImageType(), IntType()])
+@func.udf(param_types=[ImageType(), IntType()])
 def rotate(self: PIL.Image.Image, angle: int) -> PIL.Image.Image:
     return self.rotate(angle)
 
-# Image.transform()
-@func.udf(return_type= _caller_return_type, param_types=[ImageType(), ArrayType((2,), dtype=IntType()), IntType()])
-def transform(self: PIL.Image.Image, size: Tuple[int, int], method: int) -> PIL.Image.Image:
-    return self.transform(size, method)
+@func.udf(py_fn=PIL.Image.Image.effect_spread, param_types=[ImageType(), IntType()])
+def effect_spread(self: PIL.Image.Image, distance: int) -> PIL.Image.Image:
+    pass
 
-@func.udf(
-    py_fn=PIL.Image.Image.effect_spread, return_type=_caller_return_type, param_types=[ImageType(), FloatType()])
-def effect_spread(self: PIL.Image.Image, distance: float) -> PIL.Image.Image:
+@func.udf(py_fn=PIL.Image.Image.transpose, param_types=[ImageType(), IntType()])
+def transpose(self: PIL.Image.Image, method: int) -> PIL.Image.Image:
     pass
 
+@rotate.conditional_return_type
+@effect_spread.conditional_return_type
+@transpose.conditional_return_type
+def _(self: PIL.Image.Image) -> ColumnType:
+    return self.col_type
+
 @func.udf(
     py_fn=PIL.Image.Image.entropy, return_type=FloatType(), param_types=[ImageType(), ImageType(), JsonType()])
-def entropy(self: PIL.Image.Image, mask: PIL.Image.Image, histogram: Dict) -> float:
+def entropy(self: PIL.Image.Image, mask: PIL.Image.Image, extrema: Optional[list] = None) -> float:
     pass
 
 @func.udf(py_fn=PIL.Image.Image.getbands, return_type=JsonType(), param_types=[ImageType()])
@@ -103,8 +106,7 @@ def getbands(self: PIL.Image.Image) -> Tuple[str]:
 def getbbox(self: PIL.Image.Image) -> Tuple[int, int, int, int]:
     pass
 
-@func.udf(
-    py_fn=PIL.Image.Image.getcolors, return_type=JsonType(), param_types=[ImageType(), IntType()])
+@func.udf(py_fn=PIL.Image.Image.getcolors, return_type=JsonType(), param_types=[ImageType(), IntType()])
 def getcolors(self: PIL.Image.Image, maxcolors: int) -> Tuple[Tuple[int, int, int], int]:
     pass
 
@@ -114,37 +116,32 @@ def getextrema(self: PIL.Image.Image) -> Tuple[int, int]:
 
 @func.udf(
     py_fn=PIL.Image.Image.getpalette, return_type=JsonType(), param_types=[ImageType(), StringType()])
-def getpalette(self: PIL.Image.Image, mode: str) -> Tuple[int]:
+def getpalette(self: PIL.Image.Image, mode: Optional[str] = None) -> Tuple[int]:
     pass
 
 @func.udf(
-    py_fn=PIL.Image.Image.getpixel, return_type=JsonType(), param_types=[ImageType(), ArrayType((2,), dtype=IntType())])
-def getpixel(self: PIL.Image.Image, xy: Tuple[int, int]) -> Tuple[int]:
-    pass
+    return_type=JsonType(), param_types=[ImageType(), ArrayType((2,), dtype=IntType())])
+def getpixel(self: PIL.Image.Image, xy: tuple[int, int]) -> Tuple[int]:
+    # `xy` will be a list; `tuple(xy)` is necessary for pillow 9 compatibility
+    return self.getpixel(tuple(xy))
 
-@func.udf(
-    py_fn=PIL.Image.Image.getprojection, return_type=JsonType(), param_types=[ImageType()])
+@func.udf(py_fn=PIL.Image.Image.getprojection, return_type=JsonType(), param_types=[ImageType()])
 def getprojection(self: PIL.Image.Image) -> Tuple[int]:
     pass
 
-@func.udf(
-    py_fn=PIL.Image.Image.histogram, return_type=JsonType(), param_types=[ImageType(), ImageType(), JsonType()])
-def histogram(self: PIL.Image.Image, mask: PIL.Image.Image, histogram: Dict) -> Tuple[int]:
+@func.udf(py_fn=PIL.Image.Image.histogram, return_type=JsonType(), param_types=[ImageType(), ImageType(), JsonType()])
+def histogram(self: PIL.Image.Image, mask: PIL.Image.Image, extrema: Optional[list] = None) -> Tuple[int]:
     pass
 
 @func.udf(
     py_fn=PIL.Image.Image.quantize, return_type=ImageType(),
     param_types=[ImageType(), IntType(), IntType(nullable=True), IntType(), IntType(nullable=True), IntType()])
 def quantize(
-    self: PIL.Image.Image, colors: int, method: int, kmeans: int, palette: int, dither: int) -> PIL.Image.Image:
+    self: PIL.Image.Image, colors: int = 256, method: Optional[int] = None, kmeans: int = 0,
+    palette: Optional[int] = None, dither: int = Dither.FLOYDSTEINBERG) -> PIL.Image.Image:
     pass
 
 @func.udf(
     py_fn=PIL.Image.Image.reduce, return_type=ImageType(), param_types=[ImageType(), IntType(), JsonType()])
-def reduce(self: PIL.Image.Image, factor: int, filter: Tuple[int]) -> PIL.Image.Image:
-    pass
-
-@func.udf(
-    py_fn=PIL.Image.Image.transpose, return_type=_caller_return_type, param_types=[ImageType(), IntType()])
-def transpose(self: PIL.Image.Image, method: int) -> PIL.Image.Image:
+def reduce(self: PIL.Image.Image, factor: int, box: Optional[Tuple[int]]) -> PIL.Image.Image:
     pass
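
Because each method's `conditional_return_type` now propagates size, mode, and nullability from the input column, chained image expressions keep precise static types. A sketch using pixeltable's method-style image access (table and column names hypothetical):

    # resize pins the size to (224, 224), convert then pins the mode to 'L',
    # so the computed column's type is fully determined at schema time
    t.add_column(thumb=t.img.resize((224, 224)).convert('L'))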
pixeltable/functions/together.py:

@@ -85,6 +85,18 @@ def chat_completions(
     ).dict()
 
 
+_embedding_dimensions_cache = {
+    'togethercomputer/m2-bert-80M-2k-retrieval': 768,
+    'togethercomputer/m2-bert-80M-8k-retrieval': 768,
+    'togethercomputer/m2-bert-80M-32k-retrieval': 768,
+    'WhereIsAI/UAE-Large-V1': 1024,
+    'BAAI/bge-large-en-v1.5': 1024,
+    'BAAI/bge-base-en-v1.5': 768,
+    'sentence-transformers/msmarco-bert-base-dot-v5': 768,
+    'bert-base-uncased': 768,
+}
+
+
 @pxt.udf(batch_size=32, return_type=pxt.ArrayType((None,), dtype=pxt.FloatType()))
 def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
     result = together_client().embeddings.create(input=input, model=model)
@@ -94,6 +106,15 @@ def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
     ]
 
 
+@embeddings.conditional_return_type
+def _(model: str) -> pxt.ArrayType:
+    if model not in _embedding_dimensions_cache:
+        # TODO: find some other way to retrieve a sample
+        return pxt.ArrayType((None,), dtype=pxt.FloatType())
+    dimensions = _embedding_dimensions_cache[model]
+    return pxt.ArrayType((dimensions,), dtype=pxt.FloatType())
+
+
 @pxt.udf
 def image_generations(
     prompt: str,
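
As with the OpenAI variant, models listed in `_embedding_dimensions_cache` resolve to a fixed-dimension array type, and unknown models fall back to `(None,)`. For instance, given a table handle `t` with a hypothetical string column `text`:

    from pixeltable.functions.together import embeddings

    # 'BAAI/bge-base-en-v1.5' is cached at 768 dims, so the column type is Array[(768,), float]
    t.add_column(embed=embeddings(t.text, model='BAAI/bge-base-en-v1.5'))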
pixeltable/functions/util.py:

@@ -39,3 +39,14 @@ def create_nos_modules() -> List[types.ModuleType]:
         setattr(sub_module, model_id, pt_func)
 
     return new_modules
+
+
+def resolve_torch_device(device: str) -> str:
+    import torch
+    if device == 'auto':
+        if torch.cuda.is_available():
+            return 'cuda'
+        if torch.backends.mps.is_available():
+            return 'mps'
+        return 'cpu'
+    return device
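
A short usage sketch of the new helper: 'auto' resolves in the order CUDA, then Apple-silicon MPS, then CPU, and any explicit device string passes through unchanged.

    from pixeltable.functions.util import resolve_torch_device

    device = resolve_torch_device('auto')        # e.g. 'cuda' on an NVIDIA machine
    assert resolve_torch_device('cpu') == 'cpu'  # explicit devices are returned as-is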