pixeltable 0.4.0rc2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.


Files changed (59)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +9 -1
  4. pixeltable/catalog/catalog.py +333 -99
  5. pixeltable/catalog/column.py +28 -26
  6. pixeltable/catalog/globals.py +12 -0
  7. pixeltable/catalog/insertable_table.py +8 -8
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +111 -116
  10. pixeltable/catalog/table_version.py +36 -50
  11. pixeltable/catalog/table_version_handle.py +4 -1
  12. pixeltable/catalog/table_version_path.py +28 -4
  13. pixeltable/catalog/view.py +10 -18
  14. pixeltable/config.py +4 -0
  15. pixeltable/dataframe.py +10 -9
  16. pixeltable/env.py +5 -11
  17. pixeltable/exceptions.py +6 -0
  18. pixeltable/exec/exec_node.py +2 -0
  19. pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
  20. pixeltable/exec/sql_node.py +47 -30
  21. pixeltable/exprs/column_property_ref.py +2 -1
  22. pixeltable/exprs/column_ref.py +7 -6
  23. pixeltable/exprs/expr.py +4 -4
  24. pixeltable/func/__init__.py +1 -0
  25. pixeltable/func/mcp.py +74 -0
  26. pixeltable/func/query_template_function.py +4 -2
  27. pixeltable/func/tools.py +12 -2
  28. pixeltable/func/udf.py +2 -2
  29. pixeltable/functions/__init__.py +1 -0
  30. pixeltable/functions/anthropic.py +19 -45
  31. pixeltable/functions/deepseek.py +19 -38
  32. pixeltable/functions/fireworks.py +9 -18
  33. pixeltable/functions/gemini.py +2 -2
  34. pixeltable/functions/groq.py +108 -0
  35. pixeltable/functions/huggingface.py +8 -6
  36. pixeltable/functions/llama_cpp.py +6 -6
  37. pixeltable/functions/mistralai.py +16 -53
  38. pixeltable/functions/ollama.py +1 -1
  39. pixeltable/functions/openai.py +82 -170
  40. pixeltable/functions/replicate.py +2 -2
  41. pixeltable/functions/together.py +22 -80
  42. pixeltable/functions/util.py +6 -1
  43. pixeltable/globals.py +0 -2
  44. pixeltable/io/external_store.py +2 -2
  45. pixeltable/io/label_studio.py +4 -4
  46. pixeltable/io/table_data_conduit.py +1 -1
  47. pixeltable/metadata/__init__.py +1 -1
  48. pixeltable/metadata/converters/convert_37.py +15 -0
  49. pixeltable/metadata/notes.py +1 -0
  50. pixeltable/metadata/schema.py +5 -0
  51. pixeltable/plan.py +37 -121
  52. pixeltable/share/packager.py +2 -2
  53. pixeltable/type_system.py +30 -0
  54. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/METADATA +1 -1
  55. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/RECORD +58 -56
  56. pixeltable/utils/sample.py +0 -25
  57. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/LICENSE +0 -0
  58. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/WHEEL +0 -0
  59. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/entry_points.txt +0 -0
pixeltable/functions/openai.py
@@ -14,15 +14,14 @@ import math
 import pathlib
 import re
 import uuid
-from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Type, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Optional, Type
 
 import httpx
 import numpy as np
 import PIL
 
 import pixeltable as pxt
-import pixeltable.type_system as ts
-from pixeltable import env, exprs
+from pixeltable import env, exprs, type_system as ts
 from pixeltable.func import Batch, Tools
 from pixeltable.utils.code import local_public_names
 
@@ -171,15 +170,7 @@ def _get_header_info(
 
 
 @pxt.udf
-async def speech(
-    input: str,
-    *,
-    model: str,
-    voice: str,
-    response_format: Optional[str] = None,
-    speed: Optional[float] = None,
-    timeout: Optional[float] = None,
-) -> pxt.Audio:
+async def speech(input: str, *, model: str, voice: str, model_kwargs: Optional[dict[str, Any]] = None) -> pxt.Audio:
     """
     Generates audio from the input text.
 
@@ -199,8 +190,8 @@ async def speech(
         model: The model to use for speech synthesis.
         voice: The voice profile to use for speech synthesis. Supported options include:
            `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/audio/createSpeech>
+        model_kwargs: Additional keyword args for the OpenAI `audio/speech` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/audio/createSpeech>
 
     Returns:
         An audio file containing the synthesized speech.
@@ -211,30 +202,18 @@ async def speech(
 
     >>> tbl.add_computed_column(audio=speech(tbl.text, model='tts-1', voice='nova'))
     """
-    content = await _openai_client().audio.speech.create(
-        input=input,
-        model=model,
-        voice=voice,  # type: ignore
-        response_format=_opt(response_format),  # type: ignore
-        speed=_opt(speed),
-        timeout=_opt(timeout),
-    )
-    ext = response_format or 'mp3'
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    content = await _openai_client().audio.speech.create(input=input, model=model, voice=voice, **model_kwargs)
+    ext = model_kwargs.get('response_format', 'mp3')
     output_filename = str(env.Env.get().tmp_dir / f'{uuid.uuid4()}.{ext}')
     content.write_to_file(output_filename)
     return output_filename
 
 
 @pxt.udf
-async def transcriptions(
-    audio: pxt.Audio,
-    *,
-    model: str,
-    language: Optional[str] = None,
-    prompt: Optional[str] = None,
-    temperature: Optional[float] = None,
-    timeout: Optional[float] = None,
-) -> dict:
+async def transcriptions(audio: pxt.Audio, *, model: str, model_kwargs: Optional[dict[str, Any]] = None) -> dict:
     """
     Transcribes audio into the input language.
 
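Note: the per-call options that were previously wrapped with `_opt(...)` (`response_format`, `speed`, `timeout`, `language`, `prompt`, `temperature`) now travel in a single `model_kwargs` dict. A sketch of the new call style (table and column names are hypothetical):

    >>> tbl.add_computed_column(
    ...     audio=speech(tbl.text, model='tts-1', voice='nova', model_kwargs={'response_format': 'flac', 'speed': 1.25})
    ... )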
@@ -252,8 +231,8 @@ async def transcriptions(
     Args:
         audio: The audio to transcribe.
         model: The model to use for speech transcription.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranscription>
+        model_kwargs: Additional keyword args for the OpenAI `audio/transcriptions` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranscription>
 
     Returns:
         A dictionary containing the transcription and other metadata.
@@ -264,27 +243,16 @@ async def transcriptions(
 
     >>> tbl.add_computed_column(transcription=transcriptions(tbl.audio, model='whisper-1', language='en'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     file = pathlib.Path(audio)
-    transcription = await _openai_client().audio.transcriptions.create(
-        file=file,
-        model=model,
-        language=_opt(language),
-        prompt=_opt(prompt),
-        temperature=_opt(temperature),
-        timeout=_opt(timeout),
-    )
+    transcription = await _openai_client().audio.transcriptions.create(file=file, model=model, **model_kwargs)
     return transcription.dict()
 
 
 @pxt.udf
-async def translations(
-    audio: pxt.Audio,
-    *,
-    model: str,
-    prompt: Optional[str] = None,
-    temperature: Optional[float] = None,
-    timeout: Optional[float] = None,
-) -> dict:
+async def translations(audio: pxt.Audio, *, model: str, model_kwargs: Optional[dict[str, Any]] = None) -> dict:
     """
     Translates audio into English.
 
@@ -302,8 +270,8 @@ async def translations(
     Args:
         audio: The audio to translate.
         model: The model to use for speech transcription and translation.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranslation>
+        model_kwargs: Additional keyword args for the OpenAI `audio/translations` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranslation>
 
     Returns:
         A dictionary containing the translation and other metadata.
@@ -314,10 +282,11 @@ async def translations(
 
     >>> tbl.add_computed_column(translation=translations(tbl.audio, model='whisper-1', language='en'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     file = pathlib.Path(audio)
-    translation = await _openai_client().audio.translations.create(
-        file=file, model=model, prompt=_opt(prompt), temperature=_opt(temperature), timeout=_opt(timeout)
-    )
+    translation = await _openai_client().audio.translations.create(file=file, model=model, **model_kwargs)
     return translation.dict()
 
 
@@ -353,8 +322,15 @@ def _is_model_family(model: str, family: str) -> bool:
 
 
 def _chat_completions_get_request_resources(
-    messages: list, model: str, max_completion_tokens: Optional[int], max_tokens: Optional[int], n: Optional[int]
+    messages: list, model: str, model_kwargs: Optional[dict[str, Any]]
 ) -> dict[str, int]:
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    max_completion_tokens = model_kwargs.get('max_completion_tokens')
+    max_tokens = model_kwargs.get('max_tokens')
+    n = model_kwargs.get('n')
+
     completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
 
     num_tokens = 0.0
@@ -373,24 +349,9 @@ async def chat_completions(
     messages: list,
     *,
     model: str,
-    frequency_penalty: Optional[float] = None,
-    logit_bias: Optional[dict[str, int]] = None,
-    logprobs: Optional[bool] = None,
-    top_logprobs: Optional[int] = None,
-    max_completion_tokens: Optional[int] = None,
-    max_tokens: Optional[int] = None,
-    n: Optional[int] = None,
-    presence_penalty: Optional[float] = None,
-    reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None,
-    response_format: Optional[dict] = None,
-    seed: Optional[int] = None,
-    stop: Optional[list[str]] = None,
-    temperature: Optional[float] = None,
-    tools: Optional[list[dict]] = None,
-    tool_choice: Optional[dict] = None,
-    top_p: Optional[float] = None,
-    user: Optional[str] = None,
-    timeout: Optional[float] = None,
+    model_kwargs: Optional[dict[str, Any]] = None,
+    tools: Optional[list[dict[str, Any]]] = None,
+    tool_choice: Optional[dict[str, Any]] = None,
 ) -> dict:
     """
     Creates a model response for the given chat conversation.
@@ -409,8 +370,8 @@ async def chat_completions(
     Args:
         messages: A list of messages to use for chat completion, as described in the OpenAI API documentation.
         model: The model to use for chat completion.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/chat>
+        model_kwargs: Additional keyword args for the OpenAI `chat/completions` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/chat/create>
 
     Returns:
         A dictionary containing the response and other metadata.
@@ -425,22 +386,23 @@ async def chat_completions(
         ]
         tbl.add_computed_column(response=chat_completions(messages, model='gpt-4o-mini'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     if tools is not None:
-        tools = [{'type': 'function', 'function': tool} for tool in tools]
+        model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
 
-    tool_choice_: Union[str, dict, None] = None
     if tool_choice is not None:
         if tool_choice['auto']:
-            tool_choice_ = 'auto'
+            model_kwargs['tool_choice'] = 'auto'
         elif tool_choice['required']:
-            tool_choice_ = 'required'
+            model_kwargs['tool_choice'] = 'required'
         else:
             assert tool_choice['tool'] is not None
-            tool_choice_ = {'type': 'function', 'function': {'name': tool_choice['tool']}}
+            model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
 
-    extra_body: Optional[dict[str, Any]] = None
     if tool_choice is not None and not tool_choice['parallel_tool_calls']:
-        extra_body = {'parallel_tool_calls': False}
+        model_kwargs['parallel_tool_calls'] = False
 
     # make sure the pool info exists prior to making the request
     resource_pool = _rate_limits_pool(model)
@@ -448,29 +410,8 @@ async def chat_completions(
         resource_pool, lambda: OpenAIRateLimitsInfo(_chat_completions_get_request_resources)
     )
 
-    # cast(Any, ...): avoid mypy errors
     result = await _openai_client().chat.completions.with_raw_response.create(
-        messages=messages,
-        model=model,
-        frequency_penalty=_opt(frequency_penalty),
-        logit_bias=_opt(logit_bias),
-        logprobs=_opt(logprobs),
-        top_logprobs=_opt(top_logprobs),
-        max_completion_tokens=_opt(max_completion_tokens),
-        max_tokens=_opt(max_tokens),
-        n=_opt(n),
-        presence_penalty=_opt(presence_penalty),
-        reasoning_effort=_opt(reasoning_effort),
-        response_format=_opt(cast(Any, response_format)),
-        seed=_opt(seed),
-        stop=_opt(stop),
-        temperature=_opt(temperature),
-        tools=_opt(cast(Any, tools)),
-        tool_choice=_opt(cast(Any, tool_choice_)),
-        top_p=_opt(top_p),
-        user=_opt(user),
-        timeout=_opt(timeout),
-        extra_body=extra_body,
+        messages=messages, model=model, **model_kwargs
     )
 
     requests_info, tokens_info = _get_header_info(result.headers)
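The per-parameter plumbing (and the `_opt`/`cast` wrappers it required) collapses into a single `**model_kwargs` splat; `tools` and `tool_choice` keep first-class arguments and are folded into `model_kwargs` before the request goes out. A sketch of the new call style (column names are hypothetical):

    >>> messages = [{'role': 'user', 'content': tbl.prompt}]
    ... tbl.add_computed_column(
    ...     response=chat_completions(messages, model='gpt-4o-mini', model_kwargs={'temperature': 0.7, 'max_tokens': 300})
    ... )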
@@ -480,13 +421,15 @@ async def chat_completions(
 
 
 def _vision_get_request_resources(
-    prompt: str,
-    image: PIL.Image.Image,
-    model: str,
-    max_completion_tokens: Optional[int],
-    max_tokens: Optional[int],
-    n: Optional[int],
+    prompt: str, image: PIL.Image.Image, model: str, model_kwargs: Optional[dict[str, Any]] = None
 ) -> dict[str, int]:
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    max_completion_tokens = model_kwargs.get('max_completion_tokens')
+    max_tokens = model_kwargs.get('max_tokens')
+    n = model_kwargs.get('n')
+
     completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
     prompt_tokens = len(prompt) / 4
 
@@ -515,14 +458,7 @@ def _vision_get_request_resources(
 
 @pxt.udf
 async def vision(
-    prompt: str,
-    image: PIL.Image.Image,
-    *,
-    model: str,
-    max_completion_tokens: Optional[int] = None,
-    max_tokens: Optional[int] = None,
-    n: Optional[int] = 1,
-    timeout: Optional[float] = None,
+    prompt: str, image: PIL.Image.Image, *, model: str, model_kwargs: Optional[dict[str, Any]] = None
 ) -> str:
     """
     Analyzes an image with the OpenAI vision capability. This is a convenience function that takes an image and
@@ -552,6 +488,9 @@ async def vision(
 
     >>> tbl.add_computed_column(response=vision("What's in this image?", tbl.image, model='gpt-4o-mini'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
     bytes_arr = io.BytesIO()
     image.save(bytes_arr, format='png')
@@ -576,10 +515,7 @@ async def vision(
     result = await _openai_client().chat.completions.with_raw_response.create(
         messages=messages,  # type: ignore
         model=model,
-        max_completion_tokens=_opt(max_completion_tokens),
-        max_tokens=_opt(max_tokens),
-        n=_opt(n),
-        timeout=_opt(timeout),
+        **model_kwargs,
     )
 
     requests_info, tokens_info = _get_header_info(result.headers)
@@ -606,12 +542,7 @@ def _embeddings_get_request_resources(input: list[str]) -> dict[str, int]:
 
 @pxt.udf(batch_size=32)
 async def embeddings(
-    input: Batch[str],
-    *,
-    model: str,
-    dimensions: Optional[int] = None,
-    user: Optional[str] = None,
-    timeout: Optional[float] = None,
+    input: Batch[str], *, model: str, model_kwargs: Optional[dict[str, Any]] = None
 ) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Creates an embedding vector representing the input text.
@@ -630,10 +561,8 @@ async def embeddings(
     Args:
         input: The text to embed.
         model: The model to use for the embedding.
-        dimensions: The vector length of the embedding. If not specified, Pixeltable will use
-            a default value based on the model.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/embeddings>
+        model_kwargs: Additional keyword args for the OpenAI `embeddings` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/embeddings>
 
     Returns:
         An array representing the application of the given embedding to `input`.
@@ -648,18 +577,16 @@ async def embeddings(
 
     >>> tbl.add_embedding_index(embedding=embeddings.using(model='text-embedding-3-small'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     _logger.debug(f'embeddings: batch_size={len(input)}')
     resource_pool = _rate_limits_pool(model)
     rate_limits_info = env.Env.get().get_resource_pool_info(
         resource_pool, lambda: OpenAIRateLimitsInfo(_embeddings_get_request_resources)
     )
     result = await _openai_client().embeddings.with_raw_response.create(
-        input=input,
-        model=model,
-        dimensions=_opt(dimensions),
-        user=_opt(user),
-        encoding_format='float',
-        timeout=_opt(timeout),
+        input=input, model=model, encoding_format='float', **model_kwargs
     )
     requests_info, tokens_info = _get_header_info(result.headers)
     rate_limits_info.record(requests=requests_info, tokens=tokens_info)
@@ -667,7 +594,10 @@ async def embeddings(
 
 
 @embeddings.conditional_return_type
-def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
+def _(model: str, model_kwargs: Optional[dict[str, Any]] = None) -> ts.ArrayType:
+    dimensions: Optional[int] = None
+    if model_kwargs is not None:
+        dimensions = model_kwargs.get('dimensions')
     if dimensions is None:
         if model not in _embedding_dimensions_cache:
             # TODO: find some other way to retrieve a sample
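Because `dimensions` now lives inside `model_kwargs`, the conditional return type extracts it from the dict before falling back to the per-model default. Requesting an explicit vector length would look like this (a sketch based on the docstring example above):

    >>> tbl.add_embedding_index(embedding=embeddings.using(model='text-embedding-3-small', model_kwargs={'dimensions': 512}))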
@@ -682,14 +612,7 @@ def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
 
 @pxt.udf
 async def image_generations(
-    prompt: str,
-    *,
-    model: str = 'dall-e-2',
-    quality: Optional[str] = None,
-    size: Optional[str] = None,
-    style: Optional[str] = None,
-    user: Optional[str] = None,
-    timeout: Optional[float] = None,
+    prompt: str, *, model: str = 'dall-e-2', model_kwargs: Optional[dict[str, Any]] = None
 ) -> PIL.Image.Image:
     """
     Creates an image given a prompt.
@@ -708,8 +631,8 @@ async def image_generations(
     Args:
         prompt: Prompt for the image.
         model: The model to use for the generations.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/images/create>
+        model_kwargs: Additional keyword args for the OpenAI `images/generations` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/images/create>
 
     Returns:
         The generated image.
@@ -720,16 +643,12 @@ async def image_generations(
 
     >>> tbl.add_computed_column(gen_image=image_generations(tbl.text, model='dall-e-2'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
     result = await _openai_client().images.generate(
-        prompt=prompt,
-        model=_opt(model),
-        quality=_opt(quality),  # type: ignore
-        size=_opt(size),  # type: ignore
-        style=_opt(style),  # type: ignore
-        user=_opt(user),
-        response_format='b64_json',
-        timeout=_opt(timeout),
+        prompt=prompt, model=model, response_format='b64_json', **model_kwargs
     )
     b64_str = result.data[0].b64_json
     b64_bytes = base64.b64decode(b64_str)
@@ -739,9 +658,11 @@ async def image_generations(
 
 
 @image_generations.conditional_return_type
-def _(size: Optional[str] = None) -> ts.ImageType:
-    if size is None:
+def _(model_kwargs: Optional[dict[str, Any]] = None) -> ts.ImageType:
+    if model_kwargs is None or 'size' not in model_kwargs:
+        # default size is 1024x1024
         return ts.ImageType(size=(1024, 1024))
+    size = model_kwargs['size']
     x_pos = size.find('x')
     if x_pos == -1:
         return ts.ImageType()
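The return type is now inferred from `model_kwargs['size']` (defaulting to 1024x1024 when unset), so a size-specific call would be written like this (hypothetical sketch):

    >>> tbl.add_computed_column(
    ...     gen_image=image_generations(tbl.text, model='dall-e-3', model_kwargs={'size': '1792x1024'})
    ... )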
@@ -787,7 +708,7 @@ async def moderations(input: str, *, model: str = 'omni-moderation-latest') -> d
 
     >>> tbl.add_computed_column(moderations=moderations(tbl.text, model='text-moderation-stable'))
     """
-    result = await _openai_client().moderations.create(input=input, model=_opt(model))
+    result = await _openai_client().moderations.create(input=input, model=model)
     return result.dict()
 
 
@@ -826,15 +747,6 @@ def _openai_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
     return pxt_tool_calls
 
 
-_T = TypeVar('_T')
-
-
-def _opt(arg: _T) -> Union[_T, 'openai.NotGiven']:
-    import openai
-
-    return arg if arg is not None else openai.NOT_GIVEN
-
-
 __all__ = local_public_names(__name__)
 
 
pixeltable/functions/replicate.py
@@ -12,7 +12,7 @@ from pixeltable.env import Env, register_client
 from pixeltable.utils.code import local_public_names
 
 if TYPE_CHECKING:
-    import replicate  # type: ignore[import-untyped]
+    import replicate
 
 
 @register_client('replicate')
@@ -27,7 +27,7 @@ def _replicate_client() -> 'replicate.Client':
 
 
 @pxt.udf(resource_pool='request-rate:replicate')
-async def run(input: dict[str, Any], *, ref: str) -> dict[str, Any]:
+async def run(input: dict[str, Any], *, ref: str) -> pxt.Json:
     """
     Run a model on Replicate.
 
pixeltable/functions/together.py
@@ -7,7 +7,7 @@ the [Working with Together AI](https://pixeltable.readme.io/docs/together-ai) tu
 
 import base64
 import io
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
 
 import numpy as np
 import PIL.Image
@@ -50,21 +50,7 @@ def _retry(fn: Callable[..., T]) -> Callable[..., T]:
 
 
 @pxt.udf(resource_pool='request-rate:together:chat')
-async def completions(
-    prompt: str,
-    *,
-    model: str,
-    max_tokens: Optional[int] = None,
-    stop: Optional[list] = None,
-    temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
-    top_k: Optional[int] = None,
-    repetition_penalty: Optional[float] = None,
-    logprobs: Optional[int] = None,
-    echo: Optional[bool] = None,
-    n: Optional[int] = None,
-    safety_model: Optional[str] = None,
-) -> dict:
+async def completions(prompt: str, *, model: str, model_kwargs: Optional[dict[str, Any]] = None) -> dict:
     """
     Generate completions based on a given prompt using a specified model.
 
@@ -82,8 +68,8 @@ async def completions(
     Args:
         prompt: A string providing context for the model to complete.
         model: The name of the model to query.
-
-    For details on the other parameters, see: <https://docs.together.ai/reference/completions-1>
+        model_kwargs: Additional keyword arguments for the Together `completions` API.
+            For details on the available parameters, see: <https://docs.together.ai/reference/completions-1>
 
     Returns:
         A dictionary containing the response and other metadata.
@@ -94,41 +80,16 @@ async def completions(
 
     >>> tbl.add_computed_column(response=completions(tbl.prompt, model='mistralai/Mixtral-8x7B-v0.1'))
     """
-    result = await _together_client().completions.create(
-        prompt=prompt,
-        model=model,
-        max_tokens=max_tokens,
-        stop=stop,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        logprobs=logprobs,
-        echo=echo,
-        n=n,
-        safety_model=safety_model,
-    )
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    result = await _together_client().completions.create(prompt=prompt, model=model, **model_kwargs)
     return result.dict()
 
 
 @pxt.udf(resource_pool='request-rate:together:chat')
 async def chat_completions(
-    messages: list[dict[str, str]],
-    *,
-    model: str,
-    max_tokens: Optional[int] = None,
-    stop: Optional[list[str]] = None,
-    temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
-    top_k: Optional[int] = None,
-    repetition_penalty: Optional[float] = None,
-    logprobs: Optional[int] = None,
-    echo: Optional[bool] = None,
-    n: Optional[int] = None,
-    safety_model: Optional[str] = None,
-    response_format: Optional[dict] = None,
-    tools: Optional[dict] = None,
-    tool_choice: Optional[dict] = None,
+    messages: list[dict[str, str]], *, model: str, model_kwargs: Optional[dict[str, Any]] = None
 ) -> dict:
     """
     Generate chat completions based on a given prompt using a specified model.
@@ -147,8 +108,8 @@ async def chat_completions(
     Args:
         messages: A list of messages comprising the conversation so far.
         model: The name of the model to query.
-
-    For details on the other parameters, see: <https://docs.together.ai/reference/chat-completions-1>
+        model_kwargs: Additional keyword arguments for the Together `chat/completions` API.
+            For details on the available parameters, see: <https://docs.together.ai/reference/chat-completions-1>
 
     Returns:
         A dictionary containing the response and other metadata.
@@ -160,23 +121,10 @@ async def chat_completions(
     >>> messages = [{'role': 'user', 'content': tbl.prompt}]
     ... tbl.add_computed_column(response=chat_completions(messages, model='mistralai/Mixtral-8x7B-v0.1'))
     """
-    result = await _together_client().chat.completions.create(
-        messages=messages,
-        model=model,
-        max_tokens=max_tokens,
-        stop=stop,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        logprobs=logprobs,
-        echo=echo,
-        n=n,
-        safety_model=safety_model,
-        response_format=response_format,
-        tools=tools,
-        tool_choice=tool_choice,
-    )
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    result = await _together_client().chat.completions.create(messages=messages, model=model, **model_kwargs)
     return result.dict()
 
 
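The Together functions get the same treatment: the long tail of per-call parameters (`max_tokens`, `temperature`, `top_k`, `safety_model`, ...) now rides in `model_kwargs`. A sketch of the new call style (hypothetical column names):

    >>> messages = [{'role': 'user', 'content': tbl.prompt}]
    ... tbl.add_computed_column(
    ...     response=chat_completions(messages, model='mistralai/Mixtral-8x7B-v0.1', model_kwargs={'max_tokens': 256, 'temperature': 0.8})
    ... )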
@@ -236,14 +184,7 @@ def _(model: str) -> ts.ArrayType:
 
 @pxt.udf(resource_pool='request-rate:together:images')
 async def image_generations(
-    prompt: str,
-    *,
-    model: str,
-    steps: Optional[int] = None,
-    seed: Optional[int] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-    negative_prompt: Optional[str] = None,
+    prompt: str, *, model: str, model_kwargs: Optional[dict[str, Any]] = None
 ) -> PIL.Image.Image:
     """
     Generate images based on a given prompt using a specified model.
@@ -262,8 +203,8 @@ async def image_generations(
     Args:
         prompt: A description of the desired images.
         model: The model to use for image generation.
-
-    For details on the other parameters, see: <https://docs.together.ai/reference/post_images-generations>
+        model_kwargs: Additional keyword args for the Together `images/generations` API.
+            For details on the available parameters, see: <https://docs.together.ai/reference/post_images-generations>
 
     Returns:
         The generated image.
@@ -276,9 +217,10 @@ async def image_generations(
     ...     response=image_generations(tbl.prompt, model='stabilityai/stable-diffusion-xl-base-1.0')
     ... )
     """
-    result = await _together_client().images.generate(
-        prompt=prompt, model=model, steps=steps, seed=seed, height=height, width=width, negative_prompt=negative_prompt
-    )
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    result = await _together_client().images.generate(prompt=prompt, model=model, **model_kwargs)
     if result.data[0].b64_json is not None:
         b64_bytes = base64.b64decode(result.data[0].b64_json)
         img = PIL.Image.open(io.BytesIO(b64_bytes))
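Image generation follows suit; options such as `steps`, `height`, `width`, and `negative_prompt` move into `model_kwargs` (sketch):

    >>> tbl.add_computed_column(
    ...     response=image_generations(tbl.prompt, model='stabilityai/stable-diffusion-xl-base-1.0',
    ...                                model_kwargs={'steps': 30, 'height': 768, 'width': 768})
    ... )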
pixeltable/functions/util.py
@@ -1,5 +1,6 @@
 import PIL.Image
 
+from pixeltable.config import Config
 from pixeltable.env import Env
 
 
@@ -7,10 +8,14 @@ def resolve_torch_device(device: str, allow_mps: bool = True) -> str:
     Env.get().require_package('torch')
     import torch
 
+    mps_enabled = Config.get().get_bool_value('enable_mps')
+    if mps_enabled is None:
+        mps_enabled = True  # Default to True if not set in config
+
     if device == 'auto':
         if torch.cuda.is_available():
             return 'cuda'
-        if allow_mps and torch.backends.mps.is_available():
+        if mps_enabled and allow_mps and torch.backends.mps.is_available():
             return 'mps'
         return 'cpu'
     return device
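This introduces an `enable_mps` config option: when it resolves to false, device auto-selection skips Apple's MPS backend even if `torch.backends.mps.is_available()`. A usage sketch (the result is machine-dependent):

    >>> from pixeltable.functions.util import resolve_torch_device
    >>> resolve_torch_device('auto')  # 'cuda' if available; else 'mps' unless enable_mps is false; else 'cpu'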