pixeltable-0.3.2-py3-none-any.whl → pixeltable-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (150)
  1. pixeltable/__init__.py +64 -11
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/catalog.py +50 -27
  5. pixeltable/catalog/column.py +27 -11
  6. pixeltable/catalog/dir.py +6 -4
  7. pixeltable/catalog/globals.py +8 -1
  8. pixeltable/catalog/insertable_table.py +22 -12
  9. pixeltable/catalog/named_function.py +10 -6
  10. pixeltable/catalog/path.py +3 -2
  11. pixeltable/catalog/path_dict.py +8 -6
  12. pixeltable/catalog/schema_object.py +2 -1
  13. pixeltable/catalog/table.py +121 -101
  14. pixeltable/catalog/table_version.py +291 -142
  15. pixeltable/catalog/table_version_path.py +8 -5
  16. pixeltable/catalog/view.py +67 -26
  17. pixeltable/dataframe.py +106 -81
  18. pixeltable/env.py +28 -24
  19. pixeltable/exec/__init__.py +2 -2
  20. pixeltable/exec/aggregation_node.py +10 -4
  21. pixeltable/exec/cache_prefetch_node.py +5 -3
  22. pixeltable/exec/component_iteration_node.py +9 -9
  23. pixeltable/exec/data_row_batch.py +21 -10
  24. pixeltable/exec/exec_context.py +10 -3
  25. pixeltable/exec/exec_node.py +23 -12
  26. pixeltable/exec/expr_eval/evaluators.py +13 -7
  27. pixeltable/exec/expr_eval/expr_eval_node.py +24 -15
  28. pixeltable/exec/expr_eval/globals.py +30 -7
  29. pixeltable/exec/expr_eval/row_buffer.py +5 -6
  30. pixeltable/exec/expr_eval/schedulers.py +151 -31
  31. pixeltable/exec/in_memory_data_node.py +8 -7
  32. pixeltable/exec/row_update_node.py +15 -5
  33. pixeltable/exec/sql_node.py +56 -27
  34. pixeltable/exprs/__init__.py +2 -2
  35. pixeltable/exprs/arithmetic_expr.py +57 -26
  36. pixeltable/exprs/array_slice.py +1 -1
  37. pixeltable/exprs/column_property_ref.py +2 -1
  38. pixeltable/exprs/column_ref.py +20 -15
  39. pixeltable/exprs/comparison.py +6 -2
  40. pixeltable/exprs/compound_predicate.py +1 -3
  41. pixeltable/exprs/data_row.py +2 -2
  42. pixeltable/exprs/expr.py +108 -72
  43. pixeltable/exprs/expr_dict.py +2 -1
  44. pixeltable/exprs/expr_set.py +3 -1
  45. pixeltable/exprs/function_call.py +39 -41
  46. pixeltable/exprs/globals.py +1 -0
  47. pixeltable/exprs/in_predicate.py +2 -2
  48. pixeltable/exprs/inline_expr.py +20 -17
  49. pixeltable/exprs/json_mapper.py +4 -2
  50. pixeltable/exprs/json_path.py +12 -18
  51. pixeltable/exprs/literal.py +5 -9
  52. pixeltable/exprs/method_ref.py +1 -0
  53. pixeltable/exprs/object_ref.py +1 -1
  54. pixeltable/exprs/row_builder.py +32 -17
  55. pixeltable/exprs/rowid_ref.py +14 -5
  56. pixeltable/exprs/similarity_expr.py +11 -6
  57. pixeltable/exprs/sql_element_cache.py +1 -1
  58. pixeltable/exprs/type_cast.py +24 -9
  59. pixeltable/ext/__init__.py +1 -0
  60. pixeltable/ext/functions/__init__.py +1 -0
  61. pixeltable/ext/functions/whisperx.py +2 -2
  62. pixeltable/ext/functions/yolox.py +11 -11
  63. pixeltable/func/aggregate_function.py +17 -13
  64. pixeltable/func/callable_function.py +6 -6
  65. pixeltable/func/expr_template_function.py +15 -14
  66. pixeltable/func/function.py +16 -16
  67. pixeltable/func/function_registry.py +11 -8
  68. pixeltable/func/globals.py +4 -2
  69. pixeltable/func/query_template_function.py +12 -13
  70. pixeltable/func/signature.py +18 -9
  71. pixeltable/func/tools.py +10 -17
  72. pixeltable/func/udf.py +106 -11
  73. pixeltable/functions/__init__.py +21 -2
  74. pixeltable/functions/anthropic.py +16 -12
  75. pixeltable/functions/fireworks.py +63 -5
  76. pixeltable/functions/gemini.py +13 -3
  77. pixeltable/functions/globals.py +18 -6
  78. pixeltable/functions/huggingface.py +20 -38
  79. pixeltable/functions/image.py +7 -3
  80. pixeltable/functions/json.py +1 -0
  81. pixeltable/functions/llama_cpp.py +1 -4
  82. pixeltable/functions/mistralai.py +31 -20
  83. pixeltable/functions/ollama.py +4 -18
  84. pixeltable/functions/openai.py +231 -113
  85. pixeltable/functions/replicate.py +11 -10
  86. pixeltable/functions/string.py +70 -7
  87. pixeltable/functions/timestamp.py +21 -8
  88. pixeltable/functions/together.py +66 -52
  89. pixeltable/functions/video.py +1 -0
  90. pixeltable/functions/vision.py +14 -11
  91. pixeltable/functions/whisper.py +2 -1
  92. pixeltable/globals.py +60 -26
  93. pixeltable/index/__init__.py +1 -1
  94. pixeltable/index/btree.py +5 -3
  95. pixeltable/index/embedding_index.py +15 -14
  96. pixeltable/io/__init__.py +1 -1
  97. pixeltable/io/external_store.py +30 -25
  98. pixeltable/io/fiftyone.py +6 -14
  99. pixeltable/io/globals.py +33 -27
  100. pixeltable/io/hf_datasets.py +2 -1
  101. pixeltable/io/label_studio.py +77 -68
  102. pixeltable/io/pandas.py +36 -23
  103. pixeltable/io/parquet.py +9 -12
  104. pixeltable/iterators/__init__.py +1 -0
  105. pixeltable/iterators/audio.py +205 -0
  106. pixeltable/iterators/document.py +19 -8
  107. pixeltable/iterators/image.py +6 -24
  108. pixeltable/iterators/string.py +3 -6
  109. pixeltable/iterators/video.py +1 -7
  110. pixeltable/metadata/__init__.py +7 -1
  111. pixeltable/metadata/converters/convert_10.py +2 -2
  112. pixeltable/metadata/converters/convert_15.py +1 -5
  113. pixeltable/metadata/converters/convert_16.py +2 -4
  114. pixeltable/metadata/converters/convert_17.py +2 -4
  115. pixeltable/metadata/converters/convert_18.py +2 -4
  116. pixeltable/metadata/converters/convert_19.py +2 -5
  117. pixeltable/metadata/converters/convert_20.py +1 -4
  118. pixeltable/metadata/converters/convert_21.py +4 -6
  119. pixeltable/metadata/converters/convert_22.py +1 -0
  120. pixeltable/metadata/converters/convert_23.py +5 -5
  121. pixeltable/metadata/converters/convert_24.py +12 -13
  122. pixeltable/metadata/converters/convert_26.py +23 -0
  123. pixeltable/metadata/converters/util.py +3 -4
  124. pixeltable/metadata/notes.py +1 -0
  125. pixeltable/metadata/schema.py +13 -2
  126. pixeltable/plan.py +173 -98
  127. pixeltable/share/__init__.py +0 -0
  128. pixeltable/share/packager.py +218 -0
  129. pixeltable/store.py +42 -26
  130. pixeltable/type_system.py +102 -75
  131. pixeltable/utils/arrow.py +7 -8
  132. pixeltable/utils/coco.py +16 -17
  133. pixeltable/utils/code.py +1 -1
  134. pixeltable/utils/console_output.py +6 -3
  135. pixeltable/utils/description_helper.py +7 -7
  136. pixeltable/utils/documents.py +3 -1
  137. pixeltable/utils/filecache.py +12 -7
  138. pixeltable/utils/http_server.py +9 -8
  139. pixeltable/utils/iceberg.py +14 -0
  140. pixeltable/utils/media_store.py +3 -2
  141. pixeltable/utils/pytorch.py +11 -14
  142. pixeltable/utils/s3.py +1 -0
  143. pixeltable/utils/sql.py +1 -0
  144. pixeltable/utils/transactional_directory.py +2 -2
  145. {pixeltable-0.3.2.dist-info → pixeltable-0.3.4.dist-info}/METADATA +9 -9
  146. pixeltable-0.3.4.dist-info/RECORD +166 -0
  147. pixeltable-0.3.2.dist-info/RECORD +0 -161
  148. {pixeltable-0.3.2.dist-info → pixeltable-0.3.4.dist-info}/LICENSE +0 -0
  149. {pixeltable-0.3.2.dist-info → pixeltable-0.3.4.dist-info}/WHEEL +0 -0
  150. {pixeltable-0.3.2.dist-info → pixeltable-0.3.4.dist-info}/entry_points.txt +0 -0
pixeltable/functions/huggingface.py

@@ -46,7 +46,7 @@ def sentence_transformer(
     Add a computed column that applies the model `all-mpnet-base-v2` to an existing Pixeltable column `tbl.sentence`
     of the table `tbl`:

-    >>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')
+    >>> tbl.add_computed_column(result=sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2'))
     """
     env.Env.get().require_package('sentence_transformers')
     device = resolve_torch_device('auto')
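
Across this and the following docstring hunks, the release replaces the deprecated `tbl['col'] = expr` assignment syntax with an explicit `Table.add_computed_column()` call. A minimal sketch of the new pattern, assuming a fresh table with a string column (the table and column names are illustrative, and the `sentence-transformers` package must be installed):

    import pixeltable as pxt
    from pixeltable.functions.huggingface import sentence_transformer

    # Create a demo table with a single string column.
    t = pxt.create_table('sentences', {'sentence': pxt.String})

    # 0.3.2-era docstrings assigned computed columns via t['result'] = ...;
    # the 0.3.4 docstrings pass the expression as a keyword argument instead.
    t.add_computed_column(result=sentence_transformer(t.sentence, model_id='all-mpnet-base-v2'))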
@@ -111,9 +111,9 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
     Add a computed column that applies the model `ms-marco-MiniLM-L-4-v2` to the sentences in
     columns `tbl.sentence1` and `tbl.sentence2`:

-    >>> tbl['result'] = sentence_transformer(
-        tbl.sentence1, tbl.sentence2, model_id='ms-marco-MiniLM-L-4-v2'
-    )
+    >>> tbl.add_computed_column(result=sentence_transformer(
+    ...     tbl.sentence1, tbl.sentence2, model_id='ms-marco-MiniLM-L-4-v2'
+    ... ))
     """
     env.Env.get().require_package('sentence_transformers')
     device = resolve_torch_device('auto')
@@ -215,11 +215,7 @@ def _(model_id: str) -> pxt.ArrayType:

 @pxt.udf(batch_size=4)
 def detr_for_object_detection(
-    image: Batch[PIL.Image.Image],
-    *,
-    model_id: str,
-    threshold: float = 0.5,
-    revision: str = 'no_timm',
+    image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5, revision: str = 'no_timm'
 ) -> Batch[dict]:
     """
     Computes DETR object detections for the specified image. `model_id` should be a reference to a pretrained
@@ -250,11 +246,11 @@ def detr_for_object_detection(
     Add a computed column that applies the model `facebook/detr-resnet-50` to an existing
     Pixeltable column `image` of the table `tbl`:

-    >>> tbl['detections'] = detr_for_object_detection(
+    >>> tbl.add_computed_column(detections=detr_for_object_detection(
     ...     tbl.image,
     ...     model_id='facebook/detr-resnet-50',
     ...     threshold=0.8
-    ... )
+    ... ))
     """
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
@@ -287,10 +283,7 @@ def detr_for_object_detection(

 @pxt.udf(batch_size=4)
 def vit_for_image_classification(
-    image: Batch[PIL.Image.Image],
-    *,
-    model_id: str,
-    top_k: int = 5
+    image: Batch[PIL.Image.Image], *, model_id: str, top_k: int = 5
 ) -> Batch[dict[str, Any]]:
     """
     Computes image classifications for the specified image using a Vision Transformer (ViT) model.
@@ -326,11 +319,11 @@ def vit_for_image_classification(
     Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
     Pixeltable column `image` of the table `tbl`, returning the 10 most likely classes for each image:

-    >>> tbl['image_class'] = vit_for_image_classification(
+    >>> tbl.add_computed_column(image_class=vit_for_image_classification(
     ...     tbl.image,
     ...     model_id='google/vit-base-patch16-224',
     ...     top_k=10
-    ... )
+    ... ))
     """
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
@@ -362,12 +355,7 @@ def vit_for_image_classification(


 @pxt.udf
-def speech2text_for_conditional_generation(
-    audio: pxt.Audio,
-    *,
-    model_id: str,
-    language: Optional[str] = None,
-) -> str:
+def speech2text_for_conditional_generation(audio: pxt.Audio, *, model_id: str, language: Optional[str] = None) -> str:
     """
     Transcribes or translates speech to text using a Speech2Text model. `model_id` should be a reference to a
     pretrained [Speech2Text](https://huggingface.co/docs/transformers/en/model_doc/speech_to_text) model.
@@ -390,19 +378,19 @@ def speech2text_for_conditional_generation(
     Add a computed column that applies the model `facebook/s2t-small-librispeech-asr` to an existing
     Pixeltable column `audio` of the table `tbl`:

-    >>> tbl['transcription'] = speech2text_for_conditional_generation(
+    >>> tbl.add_computed_column(transcription=speech2text_for_conditional_generation(
     ...     tbl.audio,
     ...     model_id='facebook/s2t-small-librispeech-asr'
-    ... )
+    ... ))

     Add a computed column that applies the model `facebook/s2t-medium-mustc-multilingual-st` to an existing
     Pixeltable column `audio` of the table `tbl`, translating the audio to French:

-    >>> tbl['translation'] = speech2text_for_conditional_generation(
+    >>> tbl.add_computed_column(translation=speech2text_for_conditional_generation(
     ...     tbl.audio,
     ...     model_id='facebook/s2t-medium-mustc-multilingual-st',
     ...     language='fr'
-    ... )
+    ... ))
     """
     env.Env.get().require_package('transformers')
     env.Env.get().require_package('torchaudio')
@@ -419,7 +407,8 @@ def speech2text_for_conditional_generation(
     if language is not None and language not in processor.tokenizer.lang_code_to_id:
         raise excs.Error(
             f"Language code '{language}' is not supported by the model '{model_id}'. "
-            f"Supported languages are: {list(processor.tokenizer.lang_code_to_id.keys())}")
+            f'Supported languages are: {list(processor.tokenizer.lang_code_to_id.keys())}'
+        )

     forced_bos_token_id: Optional[int] = None if language is None else processor.tokenizer.lang_code_to_id[language]

@@ -439,11 +428,7 @@ def speech2text_for_conditional_generation(
     assert waveform.dim() == 1

     with torch.no_grad():
-        inputs = processor(
-            waveform,
-            sampling_rate=model_sampling_rate,
-            return_tensors='pt'
-        )
+        inputs = processor(waveform, sampling_rate=model_sampling_rate, return_tensors='pt')
         generated_ids = model.generate(**inputs.to(device), forced_bos_token_id=forced_bos_token_id).to('cpu')

     transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
@@ -466,7 +451,7 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
     Add a computed column that converts the output `tbl.detections` to COCO format, where `tbl.image`
     is the image for which detections were computed:

-    >>> tbl['detections_coco'] = detr_to_coco(tbl.image, tbl.detections)
+    >>> tbl.add_computed_column(detections_coco=detr_to_coco(tbl.image, tbl.detections))
     """
     bboxes, labels = detr_info['boxes'], detr_info['labels']
     annotations = [
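
The detection and conversion examples above chain naturally. A hedged sketch combining them, assuming an existing table `images` with an image column `image`:

    import pixeltable as pxt
    from pixeltable.functions.huggingface import detr_for_object_detection, detr_to_coco

    t = pxt.get_table('images')
    # Compute DETR detections, then re-express them as COCO-format annotations.
    t.add_computed_column(
        detections=detr_for_object_detection(t.image, model_id='facebook/detr-resnet-50', threshold=0.8)
    )
    t.add_computed_column(detections_coco=detr_to_coco(t.image, t.detections))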
@@ -480,10 +465,7 @@ T = TypeVar('T')


 def _lookup_model(
-    model_id: str,
-    create: Callable[..., T],
-    device: Optional[str] = None,
-    pass_device_to_create: bool = False
+    model_id: str, create: Callable[..., T], device: Optional[str] = None, pass_device_to_create: bool = False
 ) -> T:
     from torch import nn

pixeltable/functions/image.py

@@ -56,6 +56,7 @@ def blend(im1: PIL.Image.Image, im2: PIL.Image.Image, alpha: float) -> PIL.Image
     """
     pass

+
 @pxt.udf(substitute_fn=PIL.Image.composite, is_method=True)
 def composite(image1: PIL.Image.Image, image2: PIL.Image.Image, mask: PIL.Image.Image) -> PIL.Image.Image:
     """
@@ -68,6 +69,7 @@ def composite(image1: PIL.Image.Image, image2: PIL.Image.Image, mask: PIL.Image.

 # PIL.Image.Image methods

+
 # Image.convert()
 @pxt.udf(is_method=True)
 def convert(self: PIL.Image.Image, mode: str) -> PIL.Image.Image:
@@ -108,7 +110,9 @@ def _(self: Expr, box: tuple[int, int, int, int]) -> pxt.ColumnType:
     input_type = self.col_type
     assert isinstance(input_type, pxt.ImageType)
     if (isinstance(box, list) or isinstance(box, tuple)) and len(box) == 4 and all(isinstance(x, int) for x in box):
-        return pxt.ImageType(size=(box[2] - box[0], box[3] - box[1]), mode=input_type.mode, nullable=input_type.nullable)
+        return pxt.ImageType(
+            size=(box[2] - box[0], box[3] - box[1]), mode=input_type.mode, nullable=input_type.nullable
+        )
     return pxt.ImageType(mode=input_type.mode, nullable=input_type.nullable)  # we can't compute the size statically


@@ -339,7 +343,7 @@ def getprojection(self: PIL.Image.Image) -> tuple[int]:

 @pxt.udf(substitute_fn=PIL.Image.Image.histogram, is_method=True)
 def histogram(
-        self: PIL.Image.Image, mask: Optional[PIL.Image.Image] = None, extrema: Optional[list] = None
+    self: PIL.Image.Image, mask: Optional[PIL.Image.Image] = None, extrema: Optional[list] = None
 ) -> list[int]:
     """
     Return a histogram for the image.
@@ -375,7 +379,7 @@ def quantize(
         kmeans: The number of k-means clusters to use.
         palette: The palette to use.
         dither: The dithering method. See the [Pillow documentation](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.quantize) for a list of supported methods.
-        """
+    """
     pass


pixeltable/functions/globals.py

@@ -21,6 +21,7 @@ class make_list(pxt.Aggregator):
     """
     Collects arguments into a list.
     """
+
     def __init__(self) -> None:
         self.output: list[Any] = []

pixeltable/functions/llama_cpp.py

@@ -87,10 +87,7 @@ def _lookup_pretrained_model(repo_id: str, filename: Optional[str], n_gpu_layers
     key = (repo_id, filename, n_gpu_layers)
     if key not in _model_cache:
         llm = llama_cpp.Llama.from_pretrained(
-            repo_id=repo_id,
-            filename=filename,
-            n_gpu_layers=n_gpu_layers,
-            verbose=False,
+            repo_id=repo_id, filename=filename, n_gpu_layers=n_gpu_layers, verbose=False
         )
         _model_cache[key] = llm
     return _model_cache[key]
pixeltable/functions/mistralai.py

@@ -21,6 +21,7 @@ if TYPE_CHECKING:
 @register_client('mistral')
 def _(api_key: str) -> 'mistralai.Mistral':
     import mistralai
+
     return mistralai.Mistral(api_key=api_key)


@@ -28,8 +29,8 @@ def _mistralai_client() -> 'mistralai.Mistral':
     return Env.get().get_client('mistral')


-@pxt.udf
-def chat_completions(
+@pxt.udf(resource_pool='request-rate:mistral')
+async def chat_completions(
     messages: list[dict[str, str]],
     *,
     model: str,
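
This hunk shows the recurring 0.3.4 pattern for provider UDFs: the function becomes `async def`, and the `@pxt.udf` decorator names a `resource_pool` that the scheduler uses to throttle requests per provider (see the "Request throttling" docstring notes added below). A hedged sketch of the same shape for a hypothetical service; the pool name and function are illustrative, not part of the package:

    import pixeltable as pxt

    @pxt.udf(resource_pool='request-rate:my-service')
    async def my_service_call(prompt: str) -> str:
        # A real UDF would await an HTTP client here; this stub just echoes.
        return f'echo: {prompt}'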
@@ -47,6 +48,10 @@ def chat_completions(
     Equivalent to the Mistral AI `chat/completions` API endpoint.
     For additional details, see: <https://docs.mistral.ai/api/#tag/chat>

+    Request throttling:
+    Applies the rate limit set in the config (section `mistral`, key `rate_limit`). If no rate
+    limit is configured, uses a default of 600 RPM.
+
     __Requirements:__

     - `pip install mistralai`
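
The "Request throttling" note refers to Pixeltable's config file. A hedged sketch of the corresponding entry, assuming the default config location of `$PIXELTABLE_HOME/config.toml` (typically `~/.pixeltable/config.toml`); the value 300 is illustrative:

    # config.toml: cap Mistral requests at 300 RPM instead of the 600 RPM default
    [mistral]
    rate_limit = 300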
@@ -65,10 +70,10 @@ def chat_completions(
     to an existing Pixeltable column `tbl.prompt` of the table `tbl`:

     >>> messages = [{'role': 'user', 'content': tbl.prompt}]
-    ... tbl['response'] = completions(messages, model='mistral-latest-small')
+    ... tbl.add_computed_column(response=completions(messages, model='mistral-latest-small'))
     """
     Env.get().require_package('mistralai')
-    return _mistralai_client().chat.complete(
+    result = await _mistralai_client().chat.complete_async(
         messages=messages,  # type: ignore[arg-type]
         model=model,
         temperature=temperature,
@@ -78,11 +83,12 @@ def chat_completions(
         random_seed=_opt(random_seed),
         response_format=response_format,  # type: ignore[arg-type]
         safe_prompt=safe_prompt,
-    ).dict()
+    )
+    return result.dict()


-@pxt.udf
-def fim_completions(
+@pxt.udf(resource_pool='request-rate:mistral')
+async def fim_completions(
     prompt: str,
     *,
     model: str,
@@ -100,6 +106,10 @@ def fim_completions(
     Equivalent to the Mistral AI `fim/completions` API endpoint.
     For additional details, see: <https://docs.mistral.ai/api/#tag/fim>

+    Request throttling:
+    Applies the rate limit set in the config (section `mistral`, key `rate_limit`). If no rate
+    limit is configured, uses a default of 600 RPM.
+
     __Requirements:__

     - `pip install mistralai`
@@ -117,10 +127,10 @@ def fim_completions(
     Add a computed column that applies the model `codestral-latest`
     to an existing Pixeltable column `tbl.prompt` of the table `tbl`:

-    >>> tbl['response'] = completions(tbl.prompt, model='codestral-latest')
+    >>> tbl.add_computed_column(response=completions(tbl.prompt, model='codestral-latest'))
     """
     Env.get().require_package('mistralai')
-    return _mistralai_client().fim.complete(
+    result = await _mistralai_client().fim.complete_async(
         prompt=prompt,
         model=model,
         temperature=temperature,
@@ -129,23 +139,26 @@ def fim_completions(
         min_tokens=_opt(min_tokens),
         stop=stop,
         random_seed=_opt(random_seed),
-        suffix=_opt(suffix)
-    ).dict()
+        suffix=_opt(suffix),
+    )
+    return result.dict()


-_embedding_dimensions_cache: dict[str, int] = {
-    'mistral-embed': 1024
-}
+_embedding_dimensions_cache: dict[str, int] = {'mistral-embed': 1024}


-@pxt.udf(batch_size=16)
-def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
+@pxt.udf(batch_size=16, resource_pool='request-rate:mistral')
+async def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Embeddings API.

     Equivalent to the Mistral AI `embeddings` API endpoint.
     For additional details, see: <https://docs.mistral.ai/api/#tag/embeddings>

+    Request throttling:
+    Applies the rate limit set in the config (section `mistral`, key `rate_limit`). If no rate
+    limit is configured, uses a default of 600 RPM.
+
     __Requirements:__

     - `pip install mistralai`
@@ -158,10 +171,7 @@ def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt
         An array representing the application of the given embedding to `input`.
     """
     Env.get().require_package('mistralai')
-    result = _mistralai_client().embeddings.create(
-        inputs=input,
-        model=model,
-    )
+    result = _mistralai_client().embeddings.create(inputs=input, model=model)
     return [np.array(data.embedding, dtype=np.float64) for data in result.data]


@@ -176,6 +186,7 @@ _T = TypeVar('_T')

 def _opt(arg: Optional[_T]) -> Union[_T, 'mistralai.types.basemodel.Unset']:
     from mistralai.types import UNSET
+
     return arg if arg is not None else UNSET


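Because `embeddings` returns a float array, it can also back a similarity index. A hedged sketch, assuming a table `docs` with a string column `text` and that `Function.using()` is available to pre-bind the `model` parameter:

    import pixeltable as pxt
    from pixeltable.functions.mistralai import embeddings

    t = pxt.get_table('docs')
    # Index the text column with Mistral embeddings; lookups can then use t.text.similarity(...).
    t.add_embedding_index('text', string_embed=embeddings.using(model='mistral-embed'))
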
pixeltable/functions/ollama.py

@@ -14,6 +14,7 @@ if TYPE_CHECKING:
 @env.register_client('ollama')
 def _(host: str) -> 'ollama.Client':
     import ollama
+
     return ollama.Client(host=host)


@@ -97,22 +98,12 @@ def chat(
     import ollama

     client = _ollama_client() or ollama
-    return client.chat(
-        model=model,
-        messages=messages,
-        tools=tools,
-        format=format,
-        options=options,
-    ).dict()  # type: ignore[call-overload]
+    return client.chat(model=model, messages=messages, tools=tools, format=format, options=options).dict()  # type: ignore[call-overload]


 @pxt.udf(batch_size=16)
 def embed(
-    input: Batch[str],
-    *,
-    model: str,
-    truncate: bool = True,
-    options: Optional[dict] = None,
+    input: Batch[str], *, model: str, truncate: bool = True, options: Optional[dict] = None
 ) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Generate embeddings from a model.
@@ -131,12 +122,7 @@ def embed(
     import ollama

     client = _ollama_client() or ollama
-    results = client.embed(
-        model=model,
-        input=input,
-        truncate=truncate,
-        options=options,
-    ).dict()
+    results = client.embed(model=model, input=input, truncate=truncate, options=options).dict()
     return [np.array(data, dtype=np.float64) for data in results['embeddings']]

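
A hedged usage sketch for the reformatted `embed` UDF above, assuming a running Ollama server, a table `docs` with a string column `text`, and an illustrative model name:

    import pixeltable as pxt
    from pixeltable.functions.ollama import embed

    t = pxt.get_table('docs')
    # Store one embedding vector per row, computed by the local Ollama model.
    t.add_computed_column(text_embedding=embed(t.text, model='nomic-embed-text'))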