pixeltable 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (78)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +9 -1
  4. pixeltable/catalog/catalog.py +559 -134
  5. pixeltable/catalog/column.py +36 -32
  6. pixeltable/catalog/dir.py +1 -2
  7. pixeltable/catalog/globals.py +12 -0
  8. pixeltable/catalog/insertable_table.py +30 -25
  9. pixeltable/catalog/schema_object.py +9 -6
  10. pixeltable/catalog/table.py +334 -267
  11. pixeltable/catalog/table_version.py +358 -241
  12. pixeltable/catalog/table_version_handle.py +18 -2
  13. pixeltable/catalog/table_version_path.py +86 -16
  14. pixeltable/catalog/view.py +47 -23
  15. pixeltable/dataframe.py +198 -19
  16. pixeltable/env.py +6 -4
  17. pixeltable/exceptions.py +6 -0
  18. pixeltable/exec/__init__.py +1 -1
  19. pixeltable/exec/exec_node.py +2 -0
  20. pixeltable/exec/expr_eval/evaluators.py +4 -1
  21. pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
  22. pixeltable/exec/in_memory_data_node.py +1 -1
  23. pixeltable/exec/sql_node.py +188 -22
  24. pixeltable/exprs/column_property_ref.py +16 -6
  25. pixeltable/exprs/column_ref.py +33 -11
  26. pixeltable/exprs/comparison.py +1 -1
  27. pixeltable/exprs/data_row.py +5 -3
  28. pixeltable/exprs/expr.py +11 -4
  29. pixeltable/exprs/literal.py +2 -0
  30. pixeltable/exprs/row_builder.py +4 -6
  31. pixeltable/exprs/rowid_ref.py +8 -0
  32. pixeltable/exprs/similarity_expr.py +1 -0
  33. pixeltable/func/__init__.py +1 -0
  34. pixeltable/func/mcp.py +74 -0
  35. pixeltable/func/query_template_function.py +5 -3
  36. pixeltable/func/tools.py +12 -2
  37. pixeltable/func/udf.py +2 -2
  38. pixeltable/functions/__init__.py +1 -0
  39. pixeltable/functions/anthropic.py +19 -45
  40. pixeltable/functions/deepseek.py +19 -38
  41. pixeltable/functions/fireworks.py +9 -18
  42. pixeltable/functions/gemini.py +2 -3
  43. pixeltable/functions/groq.py +108 -0
  44. pixeltable/functions/llama_cpp.py +6 -6
  45. pixeltable/functions/mistralai.py +16 -53
  46. pixeltable/functions/ollama.py +1 -1
  47. pixeltable/functions/openai.py +82 -165
  48. pixeltable/functions/string.py +212 -58
  49. pixeltable/functions/together.py +22 -80
  50. pixeltable/globals.py +10 -4
  51. pixeltable/index/base.py +5 -0
  52. pixeltable/index/btree.py +5 -0
  53. pixeltable/index/embedding_index.py +5 -0
  54. pixeltable/io/external_store.py +10 -31
  55. pixeltable/io/label_studio.py +5 -5
  56. pixeltable/io/parquet.py +2 -2
  57. pixeltable/io/table_data_conduit.py +1 -32
  58. pixeltable/metadata/__init__.py +11 -2
  59. pixeltable/metadata/converters/convert_13.py +2 -2
  60. pixeltable/metadata/converters/convert_30.py +6 -11
  61. pixeltable/metadata/converters/convert_35.py +9 -0
  62. pixeltable/metadata/converters/convert_36.py +38 -0
  63. pixeltable/metadata/converters/convert_37.py +15 -0
  64. pixeltable/metadata/converters/util.py +3 -9
  65. pixeltable/metadata/notes.py +3 -0
  66. pixeltable/metadata/schema.py +13 -1
  67. pixeltable/plan.py +135 -12
  68. pixeltable/share/packager.py +138 -14
  69. pixeltable/share/publish.py +2 -2
  70. pixeltable/store.py +19 -13
  71. pixeltable/type_system.py +30 -0
  72. pixeltable/utils/dbms.py +1 -1
  73. pixeltable/utils/formatter.py +64 -42
  74. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/METADATA +2 -1
  75. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/RECORD +78 -73
  76. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/LICENSE +0 -0
  77. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/WHEEL +0 -0
  78. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/entry_points.txt +0 -0
pixeltable/functions/ollama.py

@@ -50,7 +50,7 @@ def generate(
         template: Prompt template to use.
         context: The context parameter returned from a previous call to `generate()`.
         raw: If `True`, no formatting will be applied to the prompt.
-        options: Additional options to pass to the `chat` call, such as `max_tokens`, `temperature`, `top_p`, and
+        options: Additional options for the Ollama `chat` call, such as `max_tokens`, `temperature`, `top_p`, and
             `top_k`. For details, see the
             [Valid Parameters and Values](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
             section of the Ollama documentation.
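
This is ollama.py's only change: a docstring clarification; the `options` dict itself is passed through unchanged. A minimal usage sketch in the style of the doctests elsewhere in this diff (the table and the `llama3.2` model name are hypothetical):

    >>> tbl.add_computed_column(
    ...     reply=generate(
    ...         tbl.prompt,
    ...         model='llama3.2',
    ...         options={'temperature': 0.8, 'top_k': 40},  # forwarded to the Ollama chat call
    ...     )
    ... )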
pixeltable/functions/openai.py

@@ -14,15 +14,14 @@ import math
 import pathlib
 import re
 import uuid
-from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Type, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Optional, Type
 
 import httpx
 import numpy as np
 import PIL
 
 import pixeltable as pxt
-import pixeltable.type_system as ts
-from pixeltable import env, exprs
+from pixeltable import env, exprs, type_system as ts
 from pixeltable.func import Batch, Tools
 from pixeltable.utils.code import local_public_names
 
@@ -171,15 +170,7 @@ def _get_header_info(
 
 
 @pxt.udf
-async def speech(
-    input: str,
-    *,
-    model: str,
-    voice: str,
-    response_format: Optional[str] = None,
-    speed: Optional[float] = None,
-    timeout: Optional[float] = None,
-) -> pxt.Audio:
+async def speech(input: str, *, model: str, voice: str, model_kwargs: Optional[dict[str, Any]] = None) -> pxt.Audio:
     """
     Generates audio from the input text.
 
@@ -199,8 +190,8 @@ async def speech(
         model: The model to use for speech synthesis.
         voice: The voice profile to use for speech synthesis. Supported options include:
             `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/audio/createSpeech>
+        model_kwargs: Additional keyword args for the OpenAI `audio/speech` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/audio/createSpeech>
 
     Returns:
         An audio file containing the synthesized speech.
@@ -211,30 +202,23 @@ async def speech(
 
     >>> tbl.add_computed_column(audio=speech(tbl.text, model='tts-1', voice='nova'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     content = await _openai_client().audio.speech.create(
         input=input,
         model=model,
         voice=voice,  # type: ignore
-        response_format=_opt(response_format),  # type: ignore
-        speed=_opt(speed),
-        timeout=_opt(timeout),
+        **model_kwargs,
     )
-    ext = response_format or 'mp3'
+    ext = model_kwargs.get('response_format', 'mp3')
     output_filename = str(env.Env.get().tmp_dir / f'{uuid.uuid4()}.{ext}')
     content.write_to_file(output_filename)
     return output_filename
 
 
 @pxt.udf
-async def transcriptions(
-    audio: pxt.Audio,
-    *,
-    model: str,
-    language: Optional[str] = None,
-    prompt: Optional[str] = None,
-    temperature: Optional[float] = None,
-    timeout: Optional[float] = None,
-) -> dict:
+async def transcriptions(audio: pxt.Audio, *, model: str, model_kwargs: Optional[dict[str, Any]] = None) -> dict:
     """
     Transcribes audio into the input language.
 
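Taken together, the `speech()` hunks above replace the discrete optional parameters (`response_format`, `speed`, `timeout`) with a single `model_kwargs` dict that is unpacked into the OpenAI client call; the output file extension now comes from `model_kwargs.get('response_format', 'mp3')`. A hedged migration sketch (table and option values hypothetical):

    # 0.3.15
    >>> tbl.add_computed_column(audio=speech(tbl.text, model='tts-1', voice='nova', response_format='wav', speed=1.25))

    # 0.4.0: the same options move into model_kwargs
    >>> tbl.add_computed_column(audio=speech(tbl.text, model='tts-1', voice='nova', model_kwargs={'response_format': 'wav', 'speed': 1.25}))
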
@@ -252,8 +236,8 @@ async def transcriptions(
     Args:
         audio: The audio to transcribe.
         model: The model to use for speech transcription.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranscription>
+        model_kwargs: Additional keyword args for the OpenAI `audio/transcriptions` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranscription>
 
     Returns:
         A dictionary containing the transcription and other metadata.
@@ -264,27 +248,16 @@ async def transcriptions(
 
     >>> tbl.add_computed_column(transcription=transcriptions(tbl.audio, model='whisper-1', language='en'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     file = pathlib.Path(audio)
-    transcription = await _openai_client().audio.transcriptions.create(
-        file=file,
-        model=model,
-        language=_opt(language),
-        prompt=_opt(prompt),
-        temperature=_opt(temperature),
-        timeout=_opt(timeout),
-    )
+    transcription = await _openai_client().audio.transcriptions.create(file=file, model=model, **model_kwargs)
     return transcription.dict()
 
 
 @pxt.udf
-async def translations(
-    audio: pxt.Audio,
-    *,
-    model: str,
-    prompt: Optional[str] = None,
-    temperature: Optional[float] = None,
-    timeout: Optional[float] = None,
-) -> dict:
+async def translations(audio: pxt.Audio, *, model: str, model_kwargs: Optional[dict[str, Any]] = None) -> dict:
     """
     Translates audio into English.
 
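`transcriptions()` and `translations()` follow the same pattern: `language`, `prompt`, `temperature`, and `timeout` are no longer named parameters. Note that the doctest retained in the diff still passes `language='en'` directly; under the new signature that call would presumably read:

    >>> tbl.add_computed_column(transcription=transcriptions(tbl.audio, model='whisper-1', model_kwargs={'language': 'en'}))
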
@@ -302,8 +275,8 @@ async def translations(
     Args:
         audio: The audio to translate.
         model: The model to use for speech transcription and translation.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranslation>
+        model_kwargs: Additional keyword args for the OpenAI `audio/translations` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranslation>
 
     Returns:
         A dictionary containing the translation and other metadata.
@@ -314,10 +287,11 @@ async def translations(
 
     >>> tbl.add_computed_column(translation=translations(tbl.audio, model='whisper-1', language='en'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     file = pathlib.Path(audio)
-    translation = await _openai_client().audio.translations.create(
-        file=file, model=model, prompt=_opt(prompt), temperature=_opt(temperature), timeout=_opt(timeout)
-    )
+    translation = await _openai_client().audio.translations.create(file=file, model=model, **model_kwargs)
     return translation.dict()
 
 
@@ -353,8 +327,15 @@ def _is_model_family(model: str, family: str) -> bool:
 
 
 def _chat_completions_get_request_resources(
-    messages: list, model: str, max_completion_tokens: Optional[int], max_tokens: Optional[int], n: Optional[int]
+    messages: list, model: str, model_kwargs: Optional[dict[str, Any]]
 ) -> dict[str, int]:
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    max_completion_tokens = model_kwargs.get('max_completion_tokens')
+    max_tokens = model_kwargs.get('max_tokens')
+    n = model_kwargs.get('n')
+
     completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
 
     num_tokens = 0.0
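
`_chat_completions_get_request_resources()` now pulls its token-budget inputs out of `model_kwargs`, but the estimate itself is unchanged. A worked example, assuming a model whose `_default_max_tokens()` is 1024:

    model_kwargs = {'n': 2, 'max_tokens': 512}
    n = model_kwargs.get('n')                                          # 2
    max_completion_tokens = model_kwargs.get('max_completion_tokens')  # None
    max_tokens = model_kwargs.get('max_tokens')                        # 512
    completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or 1024)  # 2 * 512 = 1024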
@@ -373,24 +354,9 @@ async def chat_completions(
     messages: list,
     *,
     model: str,
-    frequency_penalty: Optional[float] = None,
-    logit_bias: Optional[dict[str, int]] = None,
-    logprobs: Optional[bool] = None,
-    top_logprobs: Optional[int] = None,
-    max_completion_tokens: Optional[int] = None,
-    max_tokens: Optional[int] = None,
-    n: Optional[int] = None,
-    presence_penalty: Optional[float] = None,
-    reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None,
-    response_format: Optional[dict] = None,
-    seed: Optional[int] = None,
-    stop: Optional[list[str]] = None,
-    temperature: Optional[float] = None,
-    tools: Optional[list[dict]] = None,
-    tool_choice: Optional[dict] = None,
-    top_p: Optional[float] = None,
-    user: Optional[str] = None,
-    timeout: Optional[float] = None,
+    model_kwargs: Optional[dict[str, Any]] = None,
+    tools: Optional[list[dict[str, Any]]] = None,
+    tool_choice: Optional[dict[str, Any]] = None,
 ) -> dict:
     """
     Creates a model response for the given chat conversation.
@@ -409,8 +375,8 @@ async def chat_completions(
     Args:
         messages: A list of messages to use for chat completion, as described in the OpenAI API documentation.
         model: The model to use for chat completion.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/chat>
+        model_kwargs: Additional keyword args for the OpenAI `chat/completions` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/chat/create>
 
     Returns:
         A dictionary containing the response and other metadata.
@@ -425,22 +391,23 @@ async def chat_completions(
         ]
         tbl.add_computed_column(response=chat_completions(messages, model='gpt-4o-mini'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     if tools is not None:
-        tools = [{'type': 'function', 'function': tool} for tool in tools]
+        model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
 
-    tool_choice_: Union[str, dict, None] = None
     if tool_choice is not None:
         if tool_choice['auto']:
-            tool_choice_ = 'auto'
+            model_kwargs['tool_choice'] = 'auto'
         elif tool_choice['required']:
-            tool_choice_ = 'required'
+            model_kwargs['tool_choice'] = 'required'
         else:
             assert tool_choice['tool'] is not None
-            tool_choice_ = {'type': 'function', 'function': {'name': tool_choice['tool']}}
+            model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
 
-    extra_body: Optional[dict[str, Any]] = None
     if tool_choice is not None and not tool_choice['parallel_tool_calls']:
-        extra_body = {'parallel_tool_calls': False}
+        model_kwargs['parallel_tool_calls'] = False
 
     # make sure the pool info exists prior to making the request
     resource_pool = _rate_limits_pool(model)
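
`tools` and `tool_choice` stay as dedicated parameters because they carry Pixeltable's own tool format; the code above now writes their OpenAI-format translations into `model_kwargs` instead of separate locals. A sketch of the mapping, assuming the `tool_choice` shape implied by the keys accessed above:

    tool_choice = {'auto': False, 'required': False, 'tool': 'stock_price', 'parallel_tool_calls': False}
    # after the translation above:
    #   model_kwargs['tool_choice'] == {'type': 'function', 'function': {'name': 'stock_price'}}
    #   model_kwargs['parallel_tool_calls'] == False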
@@ -448,29 +415,8 @@ async def chat_completions(
         resource_pool, lambda: OpenAIRateLimitsInfo(_chat_completions_get_request_resources)
     )
 
-    # cast(Any, ...): avoid mypy errors
     result = await _openai_client().chat.completions.with_raw_response.create(
-        messages=messages,
-        model=model,
-        frequency_penalty=_opt(frequency_penalty),
-        logit_bias=_opt(logit_bias),
-        logprobs=_opt(logprobs),
-        top_logprobs=_opt(top_logprobs),
-        max_completion_tokens=_opt(max_completion_tokens),
-        max_tokens=_opt(max_tokens),
-        n=_opt(n),
-        presence_penalty=_opt(presence_penalty),
-        reasoning_effort=_opt(reasoning_effort),
-        response_format=_opt(cast(Any, response_format)),
-        seed=_opt(seed),
-        stop=_opt(stop),
-        temperature=_opt(temperature),
-        tools=_opt(cast(Any, tools)),
-        tool_choice=_opt(cast(Any, tool_choice_)),
-        top_p=_opt(top_p),
-        user=_opt(user),
-        timeout=_opt(timeout),
-        extra_body=extra_body,
+        messages=messages, model=model, **model_kwargs
     )
 
     requests_info, tokens_info = _get_header_info(result.headers)
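
The net effect for callers: the sixteen per-option keyword parameters of 0.3.15 (everything except `tools` and `tool_choice`) collapse into `model_kwargs`, whose keys are forwarded verbatim to `chat.completions.create()`. A hedged migration sketch:

    # 0.3.15
    >>> chat_completions(messages, model='gpt-4o-mini', temperature=0.7, max_tokens=300, seed=42)

    # 0.4.0
    >>> chat_completions(messages, model='gpt-4o-mini', model_kwargs={'temperature': 0.7, 'max_tokens': 300, 'seed': 42})

Since the keys are no longer validated by Pixeltable's own signature, a misspelled option now surfaces when the OpenAI client call executes rather than as an immediate Python `TypeError`.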
@@ -480,13 +426,15 @@ async def chat_completions(
480
426
 
481
427
 
482
428
  def _vision_get_request_resources(
483
- prompt: str,
484
- image: PIL.Image.Image,
485
- model: str,
486
- max_completion_tokens: Optional[int],
487
- max_tokens: Optional[int],
488
- n: Optional[int],
429
+ prompt: str, image: PIL.Image.Image, model: str, model_kwargs: Optional[dict[str, Any]] = None
489
430
  ) -> dict[str, int]:
431
+ if model_kwargs is None:
432
+ model_kwargs = {}
433
+
434
+ max_completion_tokens = model_kwargs.get('max_completion_tokens')
435
+ max_tokens = model_kwargs.get('max_tokens')
436
+ n = model_kwargs.get('n')
437
+
490
438
  completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
491
439
  prompt_tokens = len(prompt) / 4
492
440
 
@@ -515,14 +463,7 @@ def _vision_get_request_resources(
515
463
 
516
464
  @pxt.udf
517
465
  async def vision(
518
- prompt: str,
519
- image: PIL.Image.Image,
520
- *,
521
- model: str,
522
- max_completion_tokens: Optional[int] = None,
523
- max_tokens: Optional[int] = None,
524
- n: Optional[int] = 1,
525
- timeout: Optional[float] = None,
466
+ prompt: str, image: PIL.Image.Image, *, model: str, model_kwargs: Optional[dict[str, Any]] = None
526
467
  ) -> str:
527
468
  """
528
469
  Analyzes an image with the OpenAI vision capability. This is a convenience function that takes an image and
@@ -552,6 +493,9 @@ async def vision(
552
493
 
553
494
  >>> tbl.add_computed_column(response=vision("What's in this image?", tbl.image, model='gpt-4o-mini'))
554
495
  """
496
+ if model_kwargs is None:
497
+ model_kwargs = {}
498
+
555
499
  # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
556
500
  bytes_arr = io.BytesIO()
557
501
  image.save(bytes_arr, format='png')
@@ -576,10 +520,7 @@ async def vision(
576
520
  result = await _openai_client().chat.completions.with_raw_response.create(
577
521
  messages=messages, # type: ignore
578
522
  model=model,
579
- max_completion_tokens=_opt(max_completion_tokens),
580
- max_tokens=_opt(max_tokens),
581
- n=_opt(n),
582
- timeout=_opt(timeout),
523
+ **model_kwargs,
583
524
  )
584
525
 
585
526
  requests_info, tokens_info = _get_header_info(result.headers)
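
`vision()` gets the same treatment: everything beyond `prompt`, `image`, and `model` moves into `model_kwargs`, which also feeds the rate-limit estimator shown above. A sketch (option value hypothetical):

    >>> tbl.add_computed_column(
    ...     response=vision("What's in this image?", tbl.image, model='gpt-4o-mini', model_kwargs={'max_completion_tokens': 256})
    ... )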
@@ -606,12 +547,7 @@ def _embeddings_get_request_resources(input: list[str]) -> dict[str, int]:
 
 @pxt.udf(batch_size=32)
 async def embeddings(
-    input: Batch[str],
-    *,
-    model: str,
-    dimensions: Optional[int] = None,
-    user: Optional[str] = None,
-    timeout: Optional[float] = None,
+    input: Batch[str], *, model: str, model_kwargs: Optional[dict[str, Any]] = None
 ) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Creates an embedding vector representing the input text.
@@ -630,10 +566,8 @@ async def embeddings(
     Args:
         input: The text to embed.
         model: The model to use for the embedding.
-        dimensions: The vector length of the embedding. If not specified, Pixeltable will use
-            a default value based on the model.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/embeddings>
+        model_kwargs: Additional keyword args for the OpenAI `embeddings` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/embeddings>
 
     Returns:
         An array representing the application of the given embedding to `input`.
@@ -648,18 +582,16 @@ async def embeddings(
 
     >>> tbl.add_embedding_index(embedding=embeddings.using(model='text-embedding-3-small'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     _logger.debug(f'embeddings: batch_size={len(input)}')
     resource_pool = _rate_limits_pool(model)
     rate_limits_info = env.Env.get().get_resource_pool_info(
         resource_pool, lambda: OpenAIRateLimitsInfo(_embeddings_get_request_resources)
     )
     result = await _openai_client().embeddings.with_raw_response.create(
-        input=input,
-        model=model,
-        dimensions=_opt(dimensions),
-        user=_opt(user),
-        encoding_format='float',
-        timeout=_opt(timeout),
+        input=input, model=model, encoding_format='float', **model_kwargs
     )
     requests_info, tokens_info = _get_header_info(result.headers)
     rate_limits_info.record(requests=requests_info, tokens=tokens_info)
@@ -667,7 +599,10 @@ async def embeddings(
 
 
 @embeddings.conditional_return_type
-def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
+def _(model: str, model_kwargs: Optional[dict[str, Any]] = None) -> ts.ArrayType:
+    dimensions: Optional[int] = None
+    if model_kwargs is not None:
+        dimensions = model_kwargs.get('dimensions')
     if dimensions is None:
         if model not in _embedding_dimensions_cache:
             # TODO: find some other way to retrieve a sample
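
For `embeddings()`, the previously documented `dimensions` parameter now travels inside `model_kwargs`, and the `conditional_return_type` hook above reads it from there, so the column type can still be pinned at schema time; when `dimensions` is absent, the hook falls back to probing the model and caching the result in `_embedding_dimensions_cache`, as before. Extending the doctest from the diff (the `dimensions` value is illustrative):

    >>> tbl.add_embedding_index(embedding=embeddings.using(model='text-embedding-3-small', model_kwargs={'dimensions': 512}))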
@@ -682,14 +617,7 @@ def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
 
 @pxt.udf
 async def image_generations(
-    prompt: str,
-    *,
-    model: str = 'dall-e-2',
-    quality: Optional[str] = None,
-    size: Optional[str] = None,
-    style: Optional[str] = None,
-    user: Optional[str] = None,
-    timeout: Optional[float] = None,
+    prompt: str, *, model: str = 'dall-e-2', model_kwargs: Optional[dict[str, Any]] = None
 ) -> PIL.Image.Image:
     """
     Creates an image given a prompt.
@@ -708,8 +636,8 @@ async def image_generations(
     Args:
         prompt: Prompt for the image.
         model: The model to use for the generations.
-
-    For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/images/create>
+        model_kwargs: Additional keyword args for the OpenAI `images/generations` API. For details on the available
+            parameters, see: <https://platform.openai.com/docs/api-reference/images/create>
 
     Returns:
         The generated image.
@@ -720,16 +648,12 @@ async def image_generations(
 
     >>> tbl.add_computed_column(gen_image=image_generations(tbl.text, model='dall-e-2'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
     # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
     result = await _openai_client().images.generate(
-        prompt=prompt,
-        model=_opt(model),
-        quality=_opt(quality),  # type: ignore
-        size=_opt(size),  # type: ignore
-        style=_opt(style),  # type: ignore
-        user=_opt(user),
-        response_format='b64_json',
-        timeout=_opt(timeout),
+        prompt=prompt, model=model, response_format='b64_json', **model_kwargs
     )
     b64_str = result.data[0].b64_json
     b64_bytes = base64.b64decode(b64_str)
@@ -739,9 +663,11 @@ async def image_generations(
 
 
 @image_generations.conditional_return_type
-def _(size: Optional[str] = None) -> ts.ImageType:
-    if size is None:
+def _(model_kwargs: Optional[dict[str, Any]] = None) -> ts.ImageType:
+    if model_kwargs is None or 'size' not in model_kwargs:
+        # default size is 1024x1024
         return ts.ImageType(size=(1024, 1024))
+    size = model_kwargs['size']
     x_pos = size.find('x')
     if x_pos == -1:
         return ts.ImageType()
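
The return-type hook for `image_generations()` now looks up `size` in `model_kwargs` ('WIDTHxHEIGHT' strings, per the `find('x')` parsing above) and defaults to 1024x1024 when it is absent. A sketch, where the resulting column type would presumably be `ImageType(size=(512, 512))`:

    >>> tbl.add_computed_column(gen_image=image_generations(tbl.text, model='dall-e-2', model_kwargs={'size': '512x512'}))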
@@ -787,7 +713,7 @@ async def moderations(input: str, *, model: str = 'omni-moderation-latest') -> d
 
     >>> tbl.add_computed_column(moderations=moderations(tbl.text, model='text-moderation-stable'))
     """
-    result = await _openai_client().moderations.create(input=input, model=_opt(model))
+    result = await _openai_client().moderations.create(input=input, model=model)
     return result.dict()
 
 
@@ -826,15 +752,6 @@ def _openai_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
     return pxt_tool_calls
 
 
-_T = TypeVar('_T')
-
-
-def _opt(arg: _T) -> Union[_T, 'openai.NotGiven']:
-    import openai
-
-    return arg if arg is not None else openai.NOT_GIVEN
-
-
 __all__ = local_public_names(__name__)
 
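The removed `_opt()` helper existed to map `None` defaults onto `openai.NOT_GIVEN` so that unset parameters were omitted from requests. With `model_kwargs`, only keys the caller actually supplied exist in the first place, so plain dict unpacking achieves the same omission; a sketch of the two patterns for contrast:

    # 0.3.15: every optional field passed explicitly, None mapped to a sentinel
    temperature=_opt(temperature)  # openai.NOT_GIVEN when temperature is None

    # 0.4.0: absent keys are simply never sent
    client.chat.completions.create(messages=messages, model=model, **model_kwargs)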