pixeltable-0.2.7-py3-none-any.whl → pixeltable-0.2.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (76)
  1. pixeltable/__init__.py +15 -33
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +1 -1
  4. pixeltable/catalog/column.py +28 -16
  5. pixeltable/catalog/dir.py +2 -2
  6. pixeltable/catalog/insertable_table.py +5 -55
  7. pixeltable/catalog/named_function.py +2 -2
  8. pixeltable/catalog/schema_object.py +2 -7
  9. pixeltable/catalog/table.py +298 -204
  10. pixeltable/catalog/table_version.py +104 -139
  11. pixeltable/catalog/table_version_path.py +22 -4
  12. pixeltable/catalog/view.py +20 -10
  13. pixeltable/dataframe.py +128 -25
  14. pixeltable/env.py +21 -14
  15. pixeltable/exec/exec_context.py +5 -0
  16. pixeltable/exec/exec_node.py +1 -0
  17. pixeltable/exec/in_memory_data_node.py +29 -24
  18. pixeltable/exec/sql_scan_node.py +1 -1
  19. pixeltable/exprs/column_ref.py +13 -8
  20. pixeltable/exprs/data_row.py +4 -0
  21. pixeltable/exprs/expr.py +16 -1
  22. pixeltable/exprs/function_call.py +4 -4
  23. pixeltable/exprs/row_builder.py +29 -20
  24. pixeltable/exprs/similarity_expr.py +4 -3
  25. pixeltable/ext/functions/yolox.py +2 -1
  26. pixeltable/func/__init__.py +1 -0
  27. pixeltable/func/aggregate_function.py +14 -12
  28. pixeltable/func/callable_function.py +8 -6
  29. pixeltable/func/expr_template_function.py +13 -19
  30. pixeltable/func/function.py +3 -6
  31. pixeltable/func/query_template_function.py +84 -0
  32. pixeltable/func/signature.py +68 -23
  33. pixeltable/func/udf.py +13 -10
  34. pixeltable/functions/__init__.py +6 -91
  35. pixeltable/functions/eval.py +26 -14
  36. pixeltable/functions/fireworks.py +25 -23
  37. pixeltable/functions/globals.py +62 -0
  38. pixeltable/functions/huggingface.py +20 -16
  39. pixeltable/functions/image.py +170 -1
  40. pixeltable/functions/openai.py +95 -128
  41. pixeltable/functions/string.py +10 -2
  42. pixeltable/functions/together.py +95 -84
  43. pixeltable/functions/util.py +16 -0
  44. pixeltable/functions/video.py +94 -16
  45. pixeltable/functions/whisper.py +78 -0
  46. pixeltable/globals.py +1 -1
  47. pixeltable/io/__init__.py +10 -0
  48. pixeltable/io/external_store.py +370 -0
  49. pixeltable/io/globals.py +50 -22
  50. pixeltable/{datatransfer → io}/label_studio.py +279 -166
  51. pixeltable/io/parquet.py +1 -1
  52. pixeltable/iterators/__init__.py +9 -0
  53. pixeltable/iterators/string.py +40 -0
  54. pixeltable/metadata/__init__.py +6 -8
  55. pixeltable/metadata/converters/convert_10.py +2 -4
  56. pixeltable/metadata/converters/convert_12.py +7 -2
  57. pixeltable/metadata/converters/convert_13.py +6 -8
  58. pixeltable/metadata/converters/convert_14.py +2 -4
  59. pixeltable/metadata/converters/convert_15.py +40 -25
  60. pixeltable/metadata/converters/convert_16.py +18 -0
  61. pixeltable/metadata/converters/util.py +11 -8
  62. pixeltable/metadata/schema.py +3 -6
  63. pixeltable/plan.py +8 -7
  64. pixeltable/store.py +1 -1
  65. pixeltable/tool/create_test_db_dump.py +145 -54
  66. pixeltable/tool/embed_udf.py +9 -0
  67. pixeltable/type_system.py +1 -2
  68. pixeltable/utils/code.py +34 -0
  69. {pixeltable-0.2.7.dist-info → pixeltable-0.2.9.dist-info}/METADATA +2 -2
  70. pixeltable-0.2.9.dist-info/RECORD +131 -0
  71. pixeltable/datatransfer/__init__.py +0 -1
  72. pixeltable/datatransfer/remote.py +0 -113
  73. pixeltable/functions/pil/image.py +0 -147
  74. pixeltable-0.2.7.dist-info/RECORD +0 -126
  75. {pixeltable-0.2.7.dist-info → pixeltable-0.2.9.dist-info}/LICENSE +0 -0
  76. {pixeltable-0.2.7.dist-info → pixeltable-0.2.9.dist-info}/WHEEL +0 -0
@@ -2,26 +2,31 @@ import base64
2
2
  import io
3
3
  import pathlib
4
4
  import uuid
5
- from typing import Optional, TypeVar, Union, Callable
5
+ from typing import Optional, TypeVar, Union, Callable, TYPE_CHECKING
6
6
 
7
7
  import PIL.Image
8
8
  import numpy as np
9
- import openai
10
9
  import tenacity
11
- from openai._types import NOT_GIVEN, NotGiven
12
10
 
13
11
  import pixeltable as pxt
14
12
  import pixeltable.type_system as ts
15
13
  from pixeltable import env
16
14
  from pixeltable.func import Batch
15
+ from pixeltable.utils.code import local_public_names
16
+
17
+ if TYPE_CHECKING:
18
+ import openai
19
+ from openai._types import NotGiven
17
20
 
18
21
 
19
22
  @env.register_client('openai')
20
- def _(api_key: str) -> openai.OpenAI:
23
+ def _(api_key: str) -> 'openai.OpenAI':
24
+ import openai
25
+
21
26
  return openai.OpenAI(api_key=api_key)
22
27
 
23
28
 
24
- def _openai_client() -> openai.OpenAI:
29
+ def _openai_client() -> 'openai.OpenAI':
25
30
  return env.Env.get().get_client('openai')
26
31
 
27
32
 
@@ -29,80 +34,61 @@ def _openai_client() -> openai.OpenAI:
29
34
  # TODO(aaron-siegel): Right now this hardwires random exponential backoff with defaults suggested
30
35
  # by OpenAI. Should we investigate making this more customizable in the future?
31
36
  def _retry(fn: Callable) -> Callable:
37
+ import openai
38
+
32
39
  return tenacity.retry(
33
40
  retry=tenacity.retry_if_exception_type(openai.RateLimitError),
34
41
  wait=tenacity.wait_random_exponential(multiplier=3, max=180),
35
- stop=tenacity.stop_after_attempt(20)
42
+ stop=tenacity.stop_after_attempt(20),
36
43
  )(fn)
37
44
 
38
45
 
39
46
  #####################################
40
47
  # Audio Endpoints
41
48
 
49
+
42
50
  @pxt.udf(return_type=ts.AudioType())
43
- @_retry
44
51
  def speech(
45
- input: str,
46
- *,
47
- model: str,
48
- voice: str,
49
- response_format: Optional[str] = None,
50
- speed: Optional[float] = None
52
+ input: str, *, model: str, voice: str, response_format: Optional[str] = None, speed: Optional[float] = None
51
53
  ) -> str:
52
- content = _openai_client().audio.speech.create(
53
- input=input,
54
- model=model,
55
- voice=voice,
56
- response_format=_opt(response_format),
57
- speed=_opt(speed)
54
+ content = _retry(_openai_client().audio.speech.create)(
55
+ input=input, model=model, voice=voice, response_format=_opt(response_format), speed=_opt(speed)
58
56
  )
59
57
  ext = response_format or 'mp3'
60
- output_filename = str(env.Env.get().tmp_dir / f"{uuid.uuid4()}.{ext}")
58
+ output_filename = str(env.Env.get().tmp_dir / f'{uuid.uuid4()}.{ext}')
61
59
  content.write_to_file(output_filename)
62
60
  return output_filename
63
61
 
64
62
 
65
63
  @pxt.udf(
66
- param_types=[ts.AudioType(), ts.StringType(), ts.StringType(nullable=True),
67
- ts.StringType(nullable=True), ts.FloatType(nullable=True)]
64
+ param_types=[
65
+ ts.AudioType(),
66
+ ts.StringType(),
67
+ ts.StringType(nullable=True),
68
+ ts.StringType(nullable=True),
69
+ ts.FloatType(nullable=True),
70
+ ]
68
71
  )
69
- @_retry
70
72
  def transcriptions(
71
- audio: str,
72
- *,
73
- model: str,
74
- language: Optional[str] = None,
75
- prompt: Optional[str] = None,
76
- temperature: Optional[float] = None
73
+ audio: str,
74
+ *,
75
+ model: str,
76
+ language: Optional[str] = None,
77
+ prompt: Optional[str] = None,
78
+ temperature: Optional[float] = None,
77
79
  ) -> dict:
78
80
  file = pathlib.Path(audio)
79
- transcription = _openai_client().audio.transcriptions.create(
80
- file=file,
81
- model=model,
82
- language=_opt(language),
83
- prompt=_opt(prompt),
84
- temperature=_opt(temperature)
81
+ transcription = _retry(_openai_client().audio.transcriptions.create)(
82
+ file=file, model=model, language=_opt(language), prompt=_opt(prompt), temperature=_opt(temperature)
85
83
  )
86
84
  return transcription.dict()
87
85
 
88
86
 
89
- @pxt.udf(
90
- param_types=[ts.AudioType(), ts.StringType(), ts.StringType(nullable=True), ts.FloatType(nullable=True)]
91
- )
92
- @_retry
93
- def translations(
94
- audio: str,
95
- *,
96
- model: str,
97
- prompt: Optional[str] = None,
98
- temperature: Optional[float] = None
99
- ) -> dict:
87
+ @pxt.udf(param_types=[ts.AudioType(), ts.StringType(), ts.StringType(nullable=True), ts.FloatType(nullable=True)])
88
+ def translations(audio: str, *, model: str, prompt: Optional[str] = None, temperature: Optional[float] = None) -> dict:
100
89
  file = pathlib.Path(audio)
101
- translation = _openai_client().audio.translations.create(
102
- file=file,
103
- model=model,
104
- prompt=_opt(prompt),
105
- temperature=_opt(temperature)
90
+ translation = _retry(_openai_client().audio.translations.create)(
91
+ file=file, model=model, prompt=_opt(prompt), temperature=_opt(temperature)
106
92
  )
107
93
  return translation.dict()
108
94
 
@@ -110,29 +96,29 @@ def translations(
110
96
  #####################################
111
97
  # Chat Endpoints
112
98
 
99
+
113
100
  @pxt.udf
114
- @_retry
115
101
  def chat_completions(
116
- messages: list,
117
- *,
118
- model: str,
119
- frequency_penalty: Optional[float] = None,
120
- logit_bias: Optional[dict[str, int]] = None,
121
- logprobs: Optional[bool] = None,
122
- top_logprobs: Optional[int] = None,
123
- max_tokens: Optional[int] = None,
124
- n: Optional[int] = None,
125
- presence_penalty: Optional[float] = None,
126
- response_format: Optional[dict] = None,
127
- seed: Optional[int] = None,
128
- stop: Optional[list[str]] = None,
129
- temperature: Optional[float] = None,
130
- top_p: Optional[float] = None,
131
- tools: Optional[list[dict]] = None,
132
- tool_choice: Optional[dict] = None,
133
- user: Optional[str] = None
102
+ messages: list,
103
+ *,
104
+ model: str,
105
+ frequency_penalty: Optional[float] = None,
106
+ logit_bias: Optional[dict[str, int]] = None,
107
+ logprobs: Optional[bool] = None,
108
+ top_logprobs: Optional[int] = None,
109
+ max_tokens: Optional[int] = None,
110
+ n: Optional[int] = None,
111
+ presence_penalty: Optional[float] = None,
112
+ response_format: Optional[dict] = None,
113
+ seed: Optional[int] = None,
114
+ stop: Optional[list[str]] = None,
115
+ temperature: Optional[float] = None,
116
+ top_p: Optional[float] = None,
117
+ tools: Optional[list[dict]] = None,
118
+ tool_choice: Optional[dict] = None,
119
+ user: Optional[str] = None,
134
120
  ) -> dict:
135
- result = _openai_client().chat.completions.create(
121
+ result = _retry(_openai_client().chat.completions.create)(
136
122
  messages=messages,
137
123
  model=model,
138
124
  frequency_penalty=_opt(frequency_penalty),
@@ -149,37 +135,28 @@ def chat_completions(
149
135
  top_p=_opt(top_p),
150
136
  tools=_opt(tools),
151
137
  tool_choice=_opt(tool_choice),
152
- user=_opt(user)
138
+ user=_opt(user),
153
139
  )
154
140
  return result.dict()
155
141
 
156
142
 
157
143
  @pxt.udf
158
- @_retry
159
- def vision(
160
- prompt: str,
161
- image: PIL.Image.Image,
162
- *,
163
- model: str = 'gpt-4-vision-preview'
164
- ) -> str:
144
+ def vision(prompt: str, image: PIL.Image.Image, *, model: str = 'gpt-4-vision-preview') -> str:
165
145
  # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
166
146
  bytes_arr = io.BytesIO()
167
147
  image.save(bytes_arr, format='png')
168
148
  b64_bytes = base64.b64encode(bytes_arr.getvalue())
169
149
  b64_encoded_image = b64_bytes.decode('utf-8')
170
150
  messages = [
171
- {'role': 'user',
172
- 'content': [
173
- {'type': 'text', 'text': prompt},
174
- {'type': 'image_url', 'image_url': {
175
- 'url': f'data:image/png;base64,{b64_encoded_image}'
176
- }}
177
- ]}
151
+ {
152
+ 'role': 'user',
153
+ 'content': [
154
+ {'type': 'text', 'text': prompt},
155
+ {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{b64_encoded_image}'}},
156
+ ],
157
+ }
178
158
  ]
179
- result = _openai_client().chat.completions.create(
180
- messages=messages,
181
- model=model
182
- )
159
+ result = _retry(_openai_client().chat.completions.create)(messages=messages, model=model)
183
160
  return result.choices[0].message.content
184
161
 
185
162
 
@@ -194,25 +171,13 @@ _embedding_dimensions_cache: dict[str, int] = {
194
171
 
195
172
 
196
173
  @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
197
- @_retry
198
174
  def embeddings(
199
- input: Batch[str],
200
- *,
201
- model: str,
202
- dimensions: Optional[int] = None,
203
- user: Optional[str] = None
175
+ input: Batch[str], *, model: str, dimensions: Optional[int] = None, user: Optional[str] = None
204
176
  ) -> Batch[np.ndarray]:
205
- result = _openai_client().embeddings.create(
206
- input=input,
207
- model=model,
208
- dimensions=_opt(dimensions),
209
- user=_opt(user),
210
- encoding_format='float'
177
+ result = _retry(_openai_client().embeddings.create)(
178
+ input=input, model=model, dimensions=_opt(dimensions), user=_opt(user), encoding_format='float'
211
179
  )
212
- return [
213
- np.array(data.embedding, dtype=np.float64)
214
- for data in result.data
215
- ]
180
+ return [np.array(data.embedding, dtype=np.float64) for data in result.data]
216
181
 
217
182
 
218
183
  @embeddings.conditional_return_type
@@ -228,26 +193,26 @@ def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
228
193
  #####################################
229
194
  # Images Endpoints
230
195
 
196
+
231
197
  @pxt.udf
232
- @_retry
233
198
  def image_generations(
234
- prompt: str,
235
- *,
236
- model: Optional[str] = None,
237
- quality: Optional[str] = None,
238
- size: Optional[str] = None,
239
- style: Optional[str] = None,
240
- user: Optional[str] = None
199
+ prompt: str,
200
+ *,
201
+ model: Optional[str] = None,
202
+ quality: Optional[str] = None,
203
+ size: Optional[str] = None,
204
+ style: Optional[str] = None,
205
+ user: Optional[str] = None,
241
206
  ) -> PIL.Image.Image:
242
207
  # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
243
- result = _openai_client().images.generate(
208
+ result = _retry(_openai_client().images.generate)(
244
209
  prompt=prompt,
245
210
  model=_opt(model),
246
211
  quality=_opt(quality),
247
212
  size=_opt(size),
248
213
  style=_opt(style),
249
214
  user=_opt(user),
250
- response_format="b64_json"
215
+ response_format='b64_json',
251
216
  )
252
217
  b64_str = result.data[0].b64_json
253
218
  b64_bytes = base64.b64decode(b64_str)
@@ -264,7 +229,7 @@ def _(size: Optional[str] = None) -> ts.ImageType:
264
229
  if x_pos == -1:
265
230
  return ts.ImageType()
266
231
  try:
267
- width, height = int(size[:x_pos]), int(size[x_pos + 1:])
232
+ width, height = int(size[:x_pos]), int(size[x_pos + 1 :])
268
233
  except ValueError:
269
234
  return ts.ImageType()
270
235
  return ts.ImageType(size=(width, height))
@@ -273,22 +238,24 @@ def _(size: Optional[str] = None) -> ts.ImageType:
273
238
  #####################################
274
239
  # Moderations Endpoints
275
240
 
241
+
276
242
  @pxt.udf
277
- @_retry
278
- def moderations(
279
- input: str,
280
- *,
281
- model: Optional[str] = None
282
- ) -> dict:
283
- result = _openai_client().moderations.create(
284
- input=input,
285
- model=_opt(model)
286
- )
243
+ def moderations(input: str, *, model: Optional[str] = None) -> dict:
244
+ result = _retry(_openai_client().moderations.create)(input=input, model=_opt(model))
287
245
  return result.dict()
288
246
 
289
247
 
290
248
  _T = TypeVar('_T')
291
249
 
292
250
 
293
- def _opt(arg: _T) -> Union[_T, NotGiven]:
251
+ def _opt(arg: _T) -> Union[_T, 'NotGiven']:
252
+ from openai._types import NOT_GIVEN
253
+
294
254
  return arg if arg is not None else NOT_GIVEN
255
+
256
+
257
+ __all__ = local_public_names(__name__)
258
+
259
+
260
+ def __dir__():
261
+ return __all__
@@ -1,13 +1,21 @@
1
1
  from typing import Any
2
2
 
3
- from pixeltable.type_system import StringType
4
3
  import pixeltable.func as func
4
+ from pixeltable.type_system import StringType
5
+ from pixeltable.utils.code import local_public_names
5
6
 
6
7
 
7
8
  @func.udf(return_type=StringType(), param_types=[StringType()])
8
9
  def str_format(format_str: str, *args: Any, **kwargs: Any) -> str:
9
- """ Return a formatted version of format_str, using substitutions from args and kwargs:
10
+ """Return a formatted version of format_str, using substitutions from args and kwargs:
10
11
  - {<int>} will be replaced by the corresponding element in args
11
12
  - {<key>} will be replaced by the corresponding value in kwargs
12
13
  """
13
14
  return format_str.format(*args, **kwargs)
15
+
16
+
17
+ __all__ = local_public_names(__name__)
18
+
19
+
20
+ def __dir__():
21
+ return __all__
@@ -1,93 +1,106 @@
1
1
  import base64
2
- import io
3
- from typing import Optional
2
+ from typing import Optional, TYPE_CHECKING
4
3
 
5
4
  import PIL.Image
6
5
  import numpy as np
7
- import together
8
6
 
7
+ import io
9
8
  import pixeltable as pxt
10
9
  from pixeltable import env
11
10
  from pixeltable.func import Batch
11
+ from pixeltable.utils.code import local_public_names
12
+
13
+ if TYPE_CHECKING:
14
+ import together
12
15
 
13
16
 
14
17
  @env.register_client('together')
15
- def _(api_key: str) -> together.Together:
18
+ def _(api_key: str) -> 'together.Together':
19
+ import together
20
+
16
21
  return together.Together(api_key=api_key)
17
22
 
18
23
 
19
- def _together_client() -> together.Together:
24
+ def _together_client() -> 'together.Together':
20
25
  return env.Env.get().get_client('together')
21
26
 
22
27
 
23
28
  @pxt.udf
24
29
  def completions(
25
- prompt: str,
26
- *,
27
- model: str,
28
- max_tokens: Optional[int] = None,
29
- stop: Optional[list] = None,
30
- temperature: Optional[float] = None,
31
- top_p: Optional[float] = None,
32
- top_k: Optional[int] = None,
33
- repetition_penalty: Optional[float] = None,
34
- logprobs: Optional[int] = None,
35
- echo: Optional[bool] = None,
36
- n: Optional[int] = None,
37
- safety_model: Optional[str] = None
30
+ prompt: str,
31
+ *,
32
+ model: str,
33
+ max_tokens: Optional[int] = None,
34
+ stop: Optional[list] = None,
35
+ temperature: Optional[float] = None,
36
+ top_p: Optional[float] = None,
37
+ top_k: Optional[int] = None,
38
+ repetition_penalty: Optional[float] = None,
39
+ logprobs: Optional[int] = None,
40
+ echo: Optional[bool] = None,
41
+ n: Optional[int] = None,
42
+ safety_model: Optional[str] = None,
38
43
  ) -> dict:
39
- return _together_client().completions.create(
40
- prompt=prompt,
41
- model=model,
42
- max_tokens=max_tokens,
43
- stop=stop,
44
- temperature=temperature,
45
- top_p=top_p,
46
- top_k=top_k,
47
- repetition_penalty=repetition_penalty,
48
- logprobs=logprobs,
49
- echo=echo,
50
- n=n,
51
- safety_model=safety_model
52
- ).dict()
44
+ return (
45
+ _together_client()
46
+ .completions.create(
47
+ prompt=prompt,
48
+ model=model,
49
+ max_tokens=max_tokens,
50
+ stop=stop,
51
+ temperature=temperature,
52
+ top_p=top_p,
53
+ top_k=top_k,
54
+ repetition_penalty=repetition_penalty,
55
+ logprobs=logprobs,
56
+ echo=echo,
57
+ n=n,
58
+ safety_model=safety_model,
59
+ )
60
+ .dict()
61
+ )
53
62
 
54
63
 
55
64
  @pxt.udf
56
65
  def chat_completions(
57
- messages: list[dict[str, str]],
58
- *,
59
- model: str,
60
- max_tokens: Optional[int] = None,
61
- stop: Optional[list[str]] = None,
62
- temperature: Optional[float] = None,
63
- top_p: Optional[float] = None,
64
- top_k: Optional[int] = None,
65
- repetition_penalty: Optional[float] = None,
66
- logprobs: Optional[int] = None,
67
- echo: Optional[bool] = None,
68
- n: Optional[int] = None,
69
- safety_model: Optional[str] = None,
70
- response_format: Optional[dict] = None,
71
- tools: Optional[dict] = None,
72
- tool_choice: Optional[dict] = None
66
+ messages: list[dict[str, str]],
67
+ *,
68
+ model: str,
69
+ max_tokens: Optional[int] = None,
70
+ stop: Optional[list[str]] = None,
71
+ temperature: Optional[float] = None,
72
+ top_p: Optional[float] = None,
73
+ top_k: Optional[int] = None,
74
+ repetition_penalty: Optional[float] = None,
75
+ logprobs: Optional[int] = None,
76
+ echo: Optional[bool] = None,
77
+ n: Optional[int] = None,
78
+ safety_model: Optional[str] = None,
79
+ response_format: Optional[dict] = None,
80
+ tools: Optional[dict] = None,
81
+ tool_choice: Optional[dict] = None,
73
82
  ) -> dict:
74
- return _together_client().chat.completions.create(
75
- messages=messages,
76
- model=model,
77
- max_tokens=max_tokens,
78
- stop=stop,
79
- temperature=temperature,
80
- top_p=top_p,
81
- top_k=top_k,
82
- repetition_penalty=repetition_penalty,
83
- logprobs=logprobs,
84
- echo=echo,
85
- n=n,
86
- safety_model=safety_model,
87
- response_format=response_format,
88
- tools=tools,
89
- tool_choice=tool_choice
90
- ).dict()
83
+ return (
84
+ _together_client()
85
+ .chat.completions.create(
86
+ messages=messages,
87
+ model=model,
88
+ max_tokens=max_tokens,
89
+ stop=stop,
90
+ temperature=temperature,
91
+ top_p=top_p,
92
+ top_k=top_k,
93
+ repetition_penalty=repetition_penalty,
94
+ logprobs=logprobs,
95
+ echo=echo,
96
+ n=n,
97
+ safety_model=safety_model,
98
+ response_format=response_format,
99
+ tools=tools,
100
+ tool_choice=tool_choice,
101
+ )
102
+ .dict()
103
+ )
91
104
 
92
105
 
93
106
  _embedding_dimensions_cache = {
@@ -105,10 +118,7 @@ _embedding_dimensions_cache = {
105
118
  @pxt.udf(batch_size=32, return_type=pxt.ArrayType((None,), dtype=pxt.FloatType()))
106
119
  def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
107
120
  result = _together_client().embeddings.create(input=input, model=model)
108
- return [
109
- np.array(data.embedding, dtype=np.float64)
110
- for data in result.data
111
- ]
121
+ return [np.array(data.embedding, dtype=np.float64) for data in result.data]
112
122
 
113
123
 
114
124
  @embeddings.conditional_return_type
@@ -122,27 +132,28 @@ def _(model: str) -> pxt.ArrayType:
122
132
 
123
133
  @pxt.udf
124
134
  def image_generations(
125
- prompt: str,
126
- *,
127
- model: str,
128
- steps: Optional[int] = None,
129
- seed: Optional[int] = None,
130
- height: Optional[int] = None,
131
- width: Optional[int] = None,
132
- negative_prompt: Optional[str] = None,
135
+ prompt: str,
136
+ *,
137
+ model: str,
138
+ steps: Optional[int] = None,
139
+ seed: Optional[int] = None,
140
+ height: Optional[int] = None,
141
+ width: Optional[int] = None,
142
+ negative_prompt: Optional[str] = None,
133
143
  ) -> PIL.Image.Image:
134
144
  # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
135
145
  result = _together_client().images.generate(
136
- prompt=prompt,
137
- model=model,
138
- steps=steps,
139
- seed=seed,
140
- height=height,
141
- width=width,
142
- negative_prompt=negative_prompt
146
+ prompt=prompt, model=model, steps=steps, seed=seed, height=height, width=width, negative_prompt=negative_prompt
143
147
  )
144
148
  b64_str = result.data[0].b64_json
145
149
  b64_bytes = base64.b64decode(b64_str)
146
150
  img = PIL.Image.open(io.BytesIO(b64_bytes))
147
151
  img.load()
148
152
  return img
153
+
154
+
155
+ __all__ = local_public_names(__name__)
156
+
157
+
158
+ def __dir__():
159
+ return __all__
@@ -1,5 +1,9 @@
1
+ import PIL.Image
2
+
3
+
1
4
  def resolve_torch_device(device: str) -> str:
2
5
  import torch
6
+
3
7
  if device == 'auto':
4
8
  if torch.cuda.is_available():
5
9
  return 'cuda'
@@ -7,3 +11,15 @@ def resolve_torch_device(device: str) -> str:
7
11
  return 'mps'
8
12
  return 'cpu'
9
13
  return device
14
+
15
+
16
+ def normalize_image_mode(image: PIL.Image.Image) -> PIL.Image.Image:
17
+ """
18
+ Converts grayscale images to 3-channel for compatibility with models that only work with
19
+ multichannel input.
20
+ """
21
+ if image.mode == '1' or image.mode == 'L':
22
+ return image.convert('RGB')
23
+ if image.mode == 'LA':
24
+ return image.convert('RGBA')
25
+ return image