pixeltable 0.2.30__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (60)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/table.py +212 -173
  4. pixeltable/catalog/table_version.py +2 -1
  5. pixeltable/catalog/view.py +3 -5
  6. pixeltable/dataframe.py +52 -39
  7. pixeltable/env.py +94 -5
  8. pixeltable/exec/__init__.py +1 -1
  9. pixeltable/exec/aggregation_node.py +3 -3
  10. pixeltable/exec/cache_prefetch_node.py +13 -7
  11. pixeltable/exec/component_iteration_node.py +3 -9
  12. pixeltable/exec/data_row_batch.py +17 -5
  13. pixeltable/exec/exec_node.py +32 -12
  14. pixeltable/exec/expr_eval/__init__.py +1 -0
  15. pixeltable/exec/expr_eval/evaluators.py +245 -0
  16. pixeltable/exec/expr_eval/expr_eval_node.py +404 -0
  17. pixeltable/exec/expr_eval/globals.py +114 -0
  18. pixeltable/exec/expr_eval/row_buffer.py +76 -0
  19. pixeltable/exec/expr_eval/schedulers.py +232 -0
  20. pixeltable/exec/in_memory_data_node.py +2 -2
  21. pixeltable/exec/row_update_node.py +14 -14
  22. pixeltable/exec/sql_node.py +2 -2
  23. pixeltable/exprs/column_ref.py +5 -1
  24. pixeltable/exprs/data_row.py +50 -40
  25. pixeltable/exprs/expr.py +57 -12
  26. pixeltable/exprs/function_call.py +54 -19
  27. pixeltable/exprs/inline_expr.py +12 -21
  28. pixeltable/exprs/literal.py +25 -8
  29. pixeltable/exprs/row_builder.py +23 -0
  30. pixeltable/exprs/similarity_expr.py +4 -4
  31. pixeltable/func/__init__.py +5 -5
  32. pixeltable/func/aggregate_function.py +4 -0
  33. pixeltable/func/callable_function.py +54 -6
  34. pixeltable/func/expr_template_function.py +5 -1
  35. pixeltable/func/function.py +54 -13
  36. pixeltable/func/query_template_function.py +56 -10
  37. pixeltable/func/tools.py +51 -14
  38. pixeltable/func/udf.py +7 -1
  39. pixeltable/functions/__init__.py +1 -1
  40. pixeltable/functions/anthropic.py +108 -21
  41. pixeltable/functions/gemini.py +2 -6
  42. pixeltable/functions/huggingface.py +10 -28
  43. pixeltable/functions/openai.py +225 -28
  44. pixeltable/globals.py +8 -5
  45. pixeltable/index/embedding_index.py +90 -38
  46. pixeltable/io/label_studio.py +1 -1
  47. pixeltable/metadata/__init__.py +1 -1
  48. pixeltable/metadata/converters/convert_24.py +11 -2
  49. pixeltable/metadata/converters/convert_25.py +19 -0
  50. pixeltable/metadata/notes.py +1 -0
  51. pixeltable/plan.py +24 -9
  52. pixeltable/store.py +6 -0
  53. pixeltable/type_system.py +4 -7
  54. pixeltable/utils/arrow.py +3 -3
  55. {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/METADATA +5 -11
  56. {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/RECORD +59 -53
  57. pixeltable/exec/expr_eval_node.py +0 -232
  58. {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/LICENSE +0 -0
  59. {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/WHEEL +0 -0
  60. {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/entry_points.txt +0 -0
pixeltable/functions/huggingface.py CHANGED
@@ -144,9 +144,9 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
 
 
 @pxt.udf(batch_size=32)
-def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
+def clip(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
-    Computes a CLIP embedding for the specified text. `model_id` should be a reference to a pretrained
+    Computes a CLIP embedding for the specified text or image. `model_id` should be a reference to a pretrained
     [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
 
     __Requirements:__
@@ -164,7 +164,11 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), px
         Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
         Pixeltable column `tbl.text` of the table `tbl`:
 
-        >>> tbl['result'] = clip_text(tbl.text, model_id='openai/clip-vit-base-patch32')
+        >>> tbl.add_computed_column(
+        ...     result=clip(tbl.text, model_id='openai/clip-vit-base-patch32')
+        ... )
+
+        The same would work with an image column `tbl.image` in place of `tbl.text`.
     """
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
@@ -181,29 +185,8 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), px
     return [embeddings[i] for i in range(embeddings.shape[0])]
 
 
-@pxt.udf(batch_size=32)
-def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
-    """
-    Computes a CLIP embedding for the specified image. `model_id` should be a reference to a pretrained
-    [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
-
-    __Requirements:__
-
-    - `pip install torch transformers`
-
-    Args:
-        image: The image to embed.
-        model_id: The pretrained model to use for the embedding.
-
-    Returns:
-        An array containing the output of the embedding model.
-
-    Examples:
-        Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
-        Pixeltable column `image` of the table `tbl`:
-
-        >>> tbl['result'] = clip_image(tbl.image, model_id='openai/clip-vit-base-patch32')
-    """
+@clip.overload
+def _(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
     import torch
@@ -219,8 +202,7 @@ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Arr
     return [embeddings[i] for i in range(embeddings.shape[0])]
 
 
-@clip_text.conditional_return_type
-@clip_image.conditional_return_type
+@clip.conditional_return_type
 def _(model_id: str) -> pxt.ArrayType:
     try:
         from transformers import CLIPModel
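
The three hunks above fold `clip_text` and `clip_image` into a single polymorphic `clip` UDF: the image variant is re-registered via `@clip.overload`, and the two `conditional_return_type` hooks collapse into one. A minimal usage sketch (assuming an existing table `tbl` with a string column `text` and an image column `img`; `add_computed_column` is the call shown in the revised docstring):

import pixeltable as pxt
from pixeltable.functions.huggingface import clip

tbl = pxt.get_table('demo')  # hypothetical table with `text` and `img` columns
# one UDF now serves both modalities; overload resolution picks the matching signature
tbl.add_computed_column(text_emb=clip(tbl.text, model_id='openai/clip-vit-base-patch32'))
tbl.add_computed_column(img_emb=clip(tbl.img, model_id='openai/clip-vit-base-patch32'))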
pixeltable/functions/openai.py CHANGED
@@ -6,14 +6,18 @@ the [Working with OpenAI](https://pixeltable.readme.io/docs/working-with-openai)
 """
 
 import base64
+import datetime
 import io
 import json
+import logging
 import pathlib
+import re
 import uuid
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union, cast, Any, Type
 
-import numpy as np
 import PIL.Image
+import httpx
+import numpy as np
 import tenacity
 
 import pixeltable as pxt
@@ -24,15 +28,28 @@ from pixeltable.utils.code import local_public_names
 if TYPE_CHECKING:
     import openai
 
+_logger = logging.getLogger('pixeltable')
+
 
 @env.register_client('openai')
-def _(api_key: str) -> 'openai.OpenAI':
+def _(api_key: str) -> tuple['openai.OpenAI', 'openai.AsyncOpenAI']:
     import openai
-    return openai.OpenAI(api_key=api_key)
+    return (
+        openai.OpenAI(api_key=api_key),
+        openai.AsyncOpenAI(
+            api_key=api_key,
+            # recommended to increase limits for async client to avoid connection errors
+            http_client=httpx.AsyncClient(limits=httpx.Limits(max_keepalive_connections=100, max_connections=500)),
+        )
+    )
 
 
 def _openai_client() -> 'openai.OpenAI':
-    return env.Env.get().get_client('openai')
+    return env.Env.get().get_client('openai')[0]
+
+
+def _async_openai_client() -> 'openai.AsyncOpenAI':
+    return env.Env.get().get_client('openai')[1]
 
 
 # Exponential backoff decorator using tenacity.
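
The registered OpenAI client is now a (sync, async) pair rather than a single client, and the async client's connection pool is widened well beyond httpx's defaults (100 connections, 20 keep-alive). A standalone sketch of the same construction, with a placeholder key:

import httpx
import openai

# mirrors the registration above; 'sk-...' is a placeholder, not a real key
async_client = openai.AsyncOpenAI(
    api_key='sk-...',
    http_client=httpx.AsyncClient(
        # raised limits so many concurrent async UDF calls don't exhaust the pool
        limits=httpx.Limits(max_keepalive_connections=100, max_connections=500)
    ),
)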
@@ -47,13 +64,128 @@ def _retry(fn: Callable) -> Callable:
     )(fn)
 
 
+# models that share rate limits; see https://platform.openai.com/settings/organization/limits for details
+_shared_rate_limits = {
+    'gpt-4-turbo': [
+        'gpt-4-turbo',
+        'gpt-4-turbo-latest',
+        'gpt-4-turbo-2024-04-09',
+        'gpt-4-turbo-preview',
+        'gpt-4-0125-preview',
+        'gpt-4-1106-preview'
+    ],
+    'gpt-4o': [
+        'gpt-4o',
+        'gpt-4o-latest',
+        'gpt-4o-2024-05-13',
+        'gpt-4o-2024-08-06',
+        'gpt-4o-2024-11-20',
+        'gpt-4o-audio-preview',
+        'gpt-4o-audio-preview-2024-10-01',
+        'gpt-4o-audio-preview-2024-12-17'
+    ],
+    'gpt-4o-mini': [
+        'gpt-4o-mini',
+        'gpt-4o-mini-latest',
+        'gpt-4o-mini-2024-07-18',
+        'gpt-4o-mini-audio-preview',
+        'gpt-4o-mini-audio-preview-2024-12-17'
+    ],
+    'gpt-4o-mini-realtime-preview': [
+        'gpt-4o-mini-realtime-preview',
+        'gpt-4o-mini-realtime-preview-latest',
+        'gpt-4o-mini-realtime-preview-2024-12-17'
+    ]
+}
+
+
+def _resource_pool(model: str) -> str:
+    for model_family, models in _shared_rate_limits.items():
+        if model in models:
+            return f'rate-limits:openai:{model_family}'
+    return f'rate-limits:openai:{model}'
+
+
+class OpenAIRateLimitsInfo(env.RateLimitsInfo):
+    retryable_errors: tuple[Type[Exception], ...]
+
+    def __init__(self, get_request_resources: Callable[..., dict[str, int]]):
+        super().__init__(get_request_resources)
+        import openai
+        self.retryable_errors = (
+            openai.RateLimitError, openai.APITimeoutError, openai.UnprocessableEntityError, openai.InternalServerError
+        )
+
+    def get_retry_delay(self, exc: Exception) -> Optional[float]:
+        import openai
+
+        if not isinstance(exc, self.retryable_errors):
+            return None
+        assert isinstance(exc, openai.APIError)
+        return 1.0
+
+
+# RE pattern for duration in '*-reset' headers;
+# examples: 1d2h3ms, 4m5.6s; fractional seconds can be reported as 0.5s or 500ms
+_header_duration_pattern = re.compile(r'(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)ms)|(?:(\d+)m)?(?:([\d.]+)s)?')
+
+
+def _parse_header_duration(duration_str):
+    match = _header_duration_pattern.match(duration_str)
+    if not match:
+        raise ValueError("Invalid duration format")
+
+    days = int(match.group(1) or 0)
+    hours = int(match.group(2) or 0)
+    milliseconds = int(match.group(3) or 0)
+    minutes = int(match.group(4) or 0)
+    seconds = float(match.group(5) or 0)
+
+    return datetime.timedelta(
+        days=days,
+        hours=hours,
+        minutes=minutes,
+        seconds=seconds,
+        milliseconds=milliseconds
+    )
+
+
+def _get_header_info(
+    headers: httpx.Headers, *, requests: bool = True, tokens: bool = True
+) -> tuple[Optional[tuple[int, int, datetime.datetime]], Optional[tuple[int, int, datetime.datetime]]]:
+    assert requests or tokens
+    now = datetime.datetime.now(tz=datetime.timezone.utc)
+
+    requests_info: Optional[tuple[int, int, datetime.datetime]] = None
+    if requests:
+        requests_limit_str = headers.get('x-ratelimit-limit-requests')
+        requests_limit = int(requests_limit_str) if requests_limit_str is not None else None
+        requests_remaining_str = headers.get('x-ratelimit-remaining-requests')
+        requests_remaining = int(requests_remaining_str) if requests_remaining_str is not None else None
+        requests_reset_str = headers.get('x-ratelimit-reset-requests')
+        requests_reset_ts = now + _parse_header_duration(requests_reset_str)
+        requests_info = (requests_limit, requests_remaining, requests_reset_ts)
+
+    tokens_info: Optional[tuple[int, int, datetime.datetime]] = None
+    if tokens:
+        tokens_limit_str = headers.get('x-ratelimit-limit-tokens')
+        tokens_limit = int(tokens_limit_str) if tokens_limit_str is not None else None
+        tokens_remaining_str = headers.get('x-ratelimit-remaining-tokens')
+        tokens_remaining = int(tokens_remaining_str) if tokens_remaining_str is not None else None
+        tokens_reset_str = headers.get('x-ratelimit-reset-tokens')
+        tokens_reset_ts = now + _parse_header_duration(tokens_reset_str)
+        tokens_info = (tokens_limit, tokens_remaining, tokens_reset_ts)
+
+    return requests_info, tokens_info
+
+
 #####################################
 # Audio Endpoints
 
 @pxt.udf
 def speech(
-    input: str, *, model: str, voice: str, response_format: Optional[str] = None, speed: Optional[float] = None
+    input: str, *, model: str, voice: str, response_format: Optional[str] = None, speed: Optional[float] = None
 ) -> pxt.Audio:
     """
     Generates audio from the input text.
@@ -176,8 +308,24 @@ def translations(
 # Chat Endpoints
 
 
+def _chat_completions_get_request_resources(
+    messages: list, max_tokens: Optional[int], n: Optional[int]
+) -> dict[str, int]:
+    completion_tokens = n * max_tokens
+
+    num_tokens = 0.0
+    for message in messages:
+        num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
+        for key, value in message.items():
+            num_tokens += len(value) / 4
+            if key == "name":  # if there's a name, the role is omitted
+                num_tokens -= 1  # role is always required and always 1 token
+    num_tokens += 2  # every reply is primed with <im_start>assistant
+    return {'requests': 1, 'tokens': int(num_tokens) + completion_tokens}
+
+
 @pxt.udf
-def chat_completions(
+async def chat_completions(
     messages: list,
     *,
     model: str,
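
`_chat_completions_get_request_resources` implements the familiar ~4-characters-per-token approximation plus OpenAI's per-message overhead, and budgets `n * max_tokens` for the completion; the new non-None defaults for `max_tokens` and `n` (next hunk) keep that product well-defined. A worked example with a hypothetical message list:

messages = [{'role': 'user', 'content': 'What is the capital of France?'}]
# 4 (message overhead) + 4/4 ('user') + 30/4 (content) + 2 (reply primer) = 14.5
# completion budget: n * max_tokens = 1 * 1024
_chat_completions_get_request_resources(messages, max_tokens=1024, n=1)
# -> {'requests': 1, 'tokens': 1038}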
@@ -185,8 +333,8 @@ def chat_completions(
     logit_bias: Optional[dict[str, int]] = None,
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
-    max_tokens: Optional[int] = None,
-    n: Optional[int] = None,
+    max_tokens: Optional[int] = 1024,
+    n: Optional[int] = 1,
     presence_penalty: Optional[float] = None,
     response_format: Optional[dict] = None,
     seed: Optional[int] = None,
@@ -226,7 +374,6 @@ def chat_completions(
         ]
         tbl['response'] = chat_completions(messages, model='gpt-4o-mini')
     """
-
     if tools is not None:
         tools = [
             {
@@ -236,7 +383,25 @@ def chat_completions(
             for tool in tools
         ]
 
-    result = _retry(_openai_client().chat.completions.create)(
+    tool_choice_: Union[str, dict, None] = None
+    if tool_choice is not None:
+        if tool_choice['auto']:
+            tool_choice_ = 'auto'
+        elif tool_choice['required']:
+            tool_choice_ = 'required'
+        else:
+            assert tool_choice['tool'] is not None
+            tool_choice_ = {
+                'type': 'function',
+                'function': {'name': tool_choice['tool']}
+            }
+
+    extra_body: Optional[dict[str, Any]] = None
+    if tool_choice is not None and not tool_choice['parallel_tool_calls']:
+        extra_body = {'parallel_tool_calls': False}
+
+    # cast(Any, ...): avoid mypy errors
+    result = await _async_openai_client().chat.completions.with_raw_response.create(
         messages=messages,
         model=model,
         frequency_penalty=_opt(frequency_penalty),
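
The Pixeltable-level `tool_choice` dict is translated into OpenAI's wire format, with `parallel_tool_calls` traveling via `extra_body`. How the branches above map two hypothetical inputs:

# force one specific tool:
# {'auto': False, 'required': False, 'tool': 'stock_price', 'parallel_tool_calls': True}
#   -> tool_choice_ = {'type': 'function', 'function': {'name': 'stock_price'}}
# let the model decide, but serially:
# {'auto': True, 'required': False, 'tool': None, 'parallel_tool_calls': False}
#   -> tool_choice_ = 'auto'; extra_body = {'parallel_tool_calls': False}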
@@ -246,16 +411,25 @@ def chat_completions(
         max_tokens=_opt(max_tokens),
         n=_opt(n),
         presence_penalty=_opt(presence_penalty),
-        response_format=_opt(response_format),
+        response_format=_opt(cast(Any, response_format)),
         seed=_opt(seed),
         stop=_opt(stop),
         temperature=_opt(temperature),
         top_p=_opt(top_p),
-        tools=_opt(tools),
-        tool_choice=_opt(tool_choice),
+        tools=_opt(cast(Any, tools)),
+        tool_choice=_opt(cast(Any, tool_choice_)),
         user=_opt(user),
+        timeout=10,
+        extra_body=extra_body,
     )
-    return result.dict()
+
+    resource_pool = _resource_pool(model)
+    requests_info, tokens_info = _get_header_info(result.headers)
+    rate_limits_info = env.Env.get().get_resource_pool_info(resource_pool, lambda: OpenAIRateLimitsInfo(
+        _chat_completions_get_request_resources))
+    rate_limits_info.record(requests=requests_info, tokens=tokens_info)
+
+    return json.loads(result.text)
 
 
 @pxt.udf
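
The raw response's `x-ratelimit-*` headers recorded above are decoded by `_get_header_info` and `_parse_header_duration` from the earlier hunk. A few reset-duration inputs in the formats the regex targets:

_parse_header_duration('1d2h3ms')  # timedelta(days=1, hours=2, milliseconds=3)
_parse_header_duration('4m5.6s')   # timedelta(minutes=4, seconds=5.6)
_parse_header_duration('500ms')    # timedelta(milliseconds=500)
_parse_header_duration('0.5s')     # timedelta(seconds=0.5)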
@@ -312,8 +486,13 @@ _embedding_dimensions_cache: dict[str, int] = {
 }
 
 
+def _embeddings_get_request_resources(input: list[str]) -> dict[str, int]:
+    input_len = sum(len(s) for s in input)
+    return {'requests': 1, 'tokens': int(input_len / 4)}
+
+
 @pxt.udf(batch_size=32)
-def embeddings(
+async def embeddings(
     input: Batch[str], *, model: str, dimensions: Optional[int] = None, user: Optional[str] = None
 ) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
@@ -343,10 +522,16 @@ def embeddings(
 
         >>> tbl['embed'] = embeddings(tbl.text, model='text-embedding-3-small')
     """
-    result = _retry(_openai_client().embeddings.create)(
+    _logger.debug(f'embeddings: batch_size={len(input)}')
+    result = await _async_openai_client().embeddings.with_raw_response.create(
         input=input, model=model, dimensions=_opt(dimensions), user=_opt(user), encoding_format='float'
     )
-    return [np.array(data.embedding, dtype=np.float64) for data in result.data]
+    resource_pool = _resource_pool(model)
+    requests_info, tokens_info = _get_header_info(result.headers)
+    rate_limits_info = env.Env.get().get_resource_pool_info(
+        resource_pool, lambda: OpenAIRateLimitsInfo(_embeddings_get_request_resources))
+    rate_limits_info.record(requests=requests_info, tokens=tokens_info)
+    return [np.array(data['embedding'], dtype=np.float64) for data in json.loads(result.content)['data']]
 
 
 @embeddings.conditional_return_type
@@ -367,7 +552,7 @@ def _(model: str, dimensions: Optional[int] = None) -> pxt.ArrayType:
 def image_generations(
     prompt: str,
     *,
-    model: Optional[str] = None,
+    model: str = 'dall-e-2',
     quality: Optional[str] = None,
     size: Optional[str] = None,
     style: Optional[str] = None,
@@ -423,7 +608,7 @@ def _(size: Optional[str] = None) -> pxt.ImageType:
     if x_pos == -1:
         return pxt.ImageType()
     try:
-        width, height = int(size[:x_pos]), int(size[x_pos + 1 :])
+        width, height = int(size[:x_pos]), int(size[x_pos + 1:])
     except ValueError:
         return pxt.ImageType()
     return pxt.ImageType(size=(width, height))
@@ -434,7 +619,7 @@ def _(size: Optional[str] = None) -> pxt.ImageType:
 
 
 @pxt.udf
-def moderations(input: str, *, model: Optional[str] = None) -> dict:
+def moderations(input: str, *, model: str = 'omni-moderation-latest') -> dict:
     """
     Classifies if text is potentially harmful.
 
@@ -464,6 +649,18 @@ def moderations(input: str, *, model: Optional[str] = None) -> dict:
     return result.dict()
 
 
+# @speech.resource_pool
+# @transcriptions.resource_pool
+# @translations.resource_pool
+@chat_completions.resource_pool
+# @vision.resource_pool
+@embeddings.resource_pool
+# @image_generations.resource_pool
+# @moderations.resource_pool
+def _(model: str) -> str:
+    return _resource_pool(model)
+
+
 def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
     """Converts an OpenAI response dict to Pixeltable tool invocation format and calls `tools._invoke()`."""
     return tools._invoke(_openai_response_to_pxt_tool_calls(response))
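
`_openai_response_to_pxt_tool_calls`, restructured in the next hunk, reduces an OpenAI response to the `{tool_name: {'args': ...}}` shape that `tools._invoke()` expects. A worked example on a hypothetical response excerpt:

response = {'choices': [{'message': {'tool_calls': [
    {'function': {'name': 'get_weather', 'arguments': '{"city": "Paris"}'}},
]}}]}
# conversion result: {'get_weather': {'args': {'city': 'Paris'}}}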
@@ -471,15 +668,15 @@ def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
 @pxt.udf
 def _openai_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
+    if 'tool_calls' not in response['choices'][0]['message'] or response['choices'][0]['message']['tool_calls'] is None:
+        return None
     openai_tool_calls = response['choices'][0]['message']['tool_calls']
-    if openai_tool_calls is not None:
-        return {
-            tool_call['function']['name']: {
-                'args': json.loads(tool_call['function']['arguments'])
-            }
-            for tool_call in openai_tool_calls
+    return {
+        tool_call['function']['name']: {
+            'args': json.loads(tool_call['function']['arguments'])
         }
-    return None
+        for tool_call in openai_tool_calls
+    }
 
 
 _T = TypeVar('_T')
pixeltable/globals.py CHANGED
@@ -1,6 +1,6 @@
 import dataclasses
 import logging
-from typing import Any, Iterable, Optional, Union, Literal, Type
+from typing import Any, Iterable, Literal, Optional, Union
 from uuid import UUID
 
 import pandas as pd
@@ -26,7 +26,7 @@ def init() -> None:
 def _get_or_drop_existing_path(
     path_str: str,
-    expected_obj_type: Type[catalog.SchemaObject],
+    expected_obj_type: type[catalog.SchemaObject],
     expected_snapshot: bool,
     if_exists: catalog.IfExistsParam
 ) -> Optional[catalog.SchemaObject]:
  ) -> Optional[catalog.SchemaObject]:
@@ -289,6 +289,11 @@ def create_view(
 
     if additional_columns is None:
         additional_columns = {}
+    else:
+        # additional columns should not be in the base table
+        for col_name in additional_columns.keys():
+            if col_name in [c.name for c in tbl_version_path.columns()]:
+                raise excs.Error(f"Column {col_name!r} already exists in the base table {tbl_version_path.get_column(col_name).tbl.name}.")
     if iterator is None:
         iterator_class, iterator_args = None, None
     else:
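
A hypothetical illustration of the new check, using the error format from the added code:

t = pxt.create_table('docs', {'text': pxt.String})
pxt.create_view('docs_view', t, additional_columns={'text': pxt.String})
# raises: Column 'text' already exists in the base table docs.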
@@ -787,8 +792,6 @@ def tool(fn: func.Function, name: Optional[str] = None, description: Optional[st
     Returns:
         A `Tool` instance that can be passed to an LLM tool-calling API.
     """
-    if fn.self_path is None:
-        raise excs.Error('Only module UDFs can be used as tools (not locally defined UDFs)')
     if isinstance(fn, func.AggregateFunction):
         raise excs.Error('Aggregator UDFs cannot be used as tools')
 
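Dropping the `self_path` check means locally defined UDFs are now accepted by `tool()`. A minimal sketch with a hypothetical local UDF:

@pxt.udf
def get_weather(city: str) -> str:
    """Get the current weather for the given city."""
    return f'sunny in {city}'  # stub

weather_tool = pxt.tool(get_weather)  # previously: Error, only module UDFs allowed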
@@ -814,4 +817,4 @@
 
 
 def array(elements: Iterable) -> exprs.Expr:
-    return exprs.InlineArray(elements)
+    return exprs.Expr.from_array(elements)
pixeltable/index/embedding_index.py CHANGED
@@ -46,32 +46,79 @@ class EmbeddingIndex(IndexBase):
     index_col_type: pgvector.sqlalchemy.Vector
 
     def __init__(
-        self, c: catalog.Column, metric: str, string_embed: Optional[func.Function] = None,
-        image_embed: Optional[func.Function] = None):
+        self,
+        c: catalog.Column,
+        metric: str,
+        embed: Optional[func.Function] = None,
+        string_embed: Optional[func.Function] = None,
+        image_embed: Optional[func.Function] = None,
+    ):
+        if embed is None and string_embed is None and image_embed is None:
+            raise excs.Error('At least one of `embed`, `string_embed`, or `image_embed` must be specified')
         metric_names = [m.name.lower() for m in self.Metric]
         if metric.lower() not in metric_names:
            raise excs.Error(f'Invalid metric {metric}, must be one of {metric_names}')
         if not c.col_type.is_string_type() and not c.col_type.is_image_type():
             raise excs.Error(f'Embedding index requires string or image column')
-        if c.col_type.is_string_type() and string_embed is None:
-            raise excs.Error(f"Text embedding function is required for column {c.name} (parameter 'string_embed')")
-        if c.col_type.is_image_type() and image_embed is None:
-            raise excs.Error(f"Image embedding function is required for column {c.name} (parameter 'image_embed')")
-
-        if string_embed is None:
-            self.string_embed = None
-        else:
-            # verify signature and convert to a monomorphic function
-            self.string_embed = self._validate_embedding_fn(string_embed, 'string_embed', ts.ColumnType.Type.STRING)
 
-        if image_embed is None:
-            self.image_embed = None
-        else:
-            # verify signature and convert to a monomorphic function
-            self.image_embed = self._validate_embedding_fn(image_embed, 'image_embed', ts.ColumnType.Type.IMAGE)
+        self.string_embed = None
+        self.image_embed = None
+
+        # Resolve the specific embedding functions corresponding to the user-provided `string_embed`, `image_embed`,
+        # and/or `embed`. For string embeddings, `string_embed` will be used if specified; otherwise, `embed` will
+        # be used as a fallback, if it has a matching signature. Likewise for image embeddings.
+
+        if string_embed is not None:
+            # `string_embed` is specified; it MUST be valid.
+            self.string_embed = self._resolve_embedding_fn(string_embed, ts.ColumnType.Type.STRING)
+            if self.string_embed is None:
+                raise excs.Error(
+                    f'The function `{string_embed.name}` is not a valid string embedding: '
+                    'it must take a single string parameter'
+                )
+        elif embed is not None:
+            # `embed` is specified; see if it has a string signature.
+            self.string_embed = self._resolve_embedding_fn(embed, ts.ColumnType.Type.STRING)
+
+        if image_embed is not None:
+            # `image_embed` is specified; it MUST be valid.
+            self.image_embed = self._resolve_embedding_fn(image_embed, ts.ColumnType.Type.IMAGE)
+            if self.image_embed is None:
+                raise excs.Error(
+                    f'The function `{image_embed.name}` is not a valid image embedding: '
+                    'it must take a single image parameter'
+                )
+        elif embed is not None:
+            # `embed` is specified; see if it has an image signature.
+            self.image_embed = self._resolve_embedding_fn(embed, ts.ColumnType.Type.IMAGE)
+
+        if self.string_embed is None and self.image_embed is None:
+            # No string OR image signature was found. This can only happen if `embed` was specified and
+            # contains no matching signatures.
+            assert embed is not None
+            raise excs.Error(
+                f'The function `{embed.name}` is not a valid embedding: '
+                'it must take a single string or image parameter'
+            )
+
+        # Now validate the return types of the embedding functions.
+
+        if self.string_embed is not None:
+            self._validate_embedding_fn(self.string_embed, ts.ColumnType.Type.STRING)
+
+        if self.image_embed is not None:
+            self._validate_embedding_fn(self.image_embed, ts.ColumnType.Type.IMAGE)
+
+        if c.col_type.is_string_type() and self.string_embed is None:
+            raise excs.Error(f"Text embedding function is required for column {c.name} (parameter 'string_embed')")
+        if c.col_type.is_image_type() and self.image_embed is None:
+            raise excs.Error(f"Image embedding function is required for column {c.name} (parameter 'image_embed')")
 
         self.metric = self.Metric[metric.upper()]
-        self.value_expr = string_embed(exprs.ColumnRef(c)) if c.col_type.is_string_type() else image_embed(exprs.ColumnRef(c))
+        self.value_expr = (
+            self.string_embed(exprs.ColumnRef(c)) if c.col_type.is_string_type()
+            else self.image_embed(exprs.ColumnRef(c))
+        )
         assert isinstance(self.value_expr.col_type, ts.ArrayType)
         vector_size = self.value_expr.col_type.shape[0]
         assert vector_size is not None
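
With the new `embed` parameter, one polymorphic function (such as the unified `clip` above) can back both modalities, and explicit `string_embed`/`image_embed` still take precedence when given. Assuming the public `add_embedding_index` API forwards these options via its `embedding` parameter (spelling per current Pixeltable docs, not shown in this diff):

from pixeltable.functions.huggingface import clip

# one function with string and image overloads covers either column type
tbl.add_embedding_index('img', embedding=clip.using(model_id='openai/clip-vit-base-patch32'))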
@@ -144,42 +191,47 @@ class EmbeddingIndex(IndexBase):
         return 'embedding'
 
     @classmethod
-    def _validate_embedding_fn(cls, embed_fn: func.Function, name: str, expected_type: ts.ColumnType.Type) -> func.Function:
-        """Validate that the Function has a matching signature, and return the corresponding monomorphic function."""
+    def _resolve_embedding_fn(cls, embed_fn: func.Function, expected_type: ts.ColumnType.Type) -> Optional[func.Function]:
+        """Find an overload resolution for `embed_fn` that matches the given type."""
         assert isinstance(embed_fn, func.Function)
-
-        signature_idx: int = -1
-        for idx, sig in enumerate(embed_fn.signatures):
+        for resolved_fn in embed_fn._resolved_fns:
            # The embedding function must be a 1-ary function of the correct type. But it's ok if the function signature
            # has more than one parameter, as long as it has at most one *required* parameter.
+            sig = resolved_fn.signature
            if (len(sig.parameters) >= 1
                    and len(sig.required_parameters) <= 1
                    and sig.parameters_by_pos[0].col_type.type_enum == expected_type):
-                signature_idx = idx
-                break
-
-        if signature_idx == -1:
-            raise excs.Error(f'{name} must take a single {expected_type.name.lower()} parameter')
+                return resolved_fn
+        return None
 
-        resolved_fn = embed_fn._resolved_fns[signature_idx]
+    @classmethod
+    def _validate_embedding_fn(cls, embed_fn: func.Function, expected_type: ts.ColumnType.Type) -> None:
+        """Validate the given embedding function."""
+        assert not embed_fn.is_polymorphic
+        sig = embed_fn.signature
 
         # validate return type
         param_name = sig.parameters_by_pos[0].name
         if expected_type == ts.ColumnType.Type.STRING:
-            return_type = resolved_fn.call_return_type([], {param_name: 'dummy'})
+            return_type = embed_fn.call_return_type([], {param_name: 'dummy'})
         else:
            assert expected_type == ts.ColumnType.Type.IMAGE
            img = PIL.Image.new('RGB', (512, 512))
-            return_type = resolved_fn.call_return_type([], {param_name: img})
+            return_type = embed_fn.call_return_type([], {param_name: img})
+
         assert return_type is not None
         if not isinstance(return_type, ts.ArrayType):
-            raise excs.Error(f'{name} must return an array, but returns {return_type}')
-        else:
-            shape = return_type.shape
-            if len(shape) != 1 or shape[0] == None:
-                raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {return_type}')
-
-        return resolved_fn
+            raise excs.Error(
+                f'The function `{embed_fn.name}` is not a valid embedding: '
+                f'it must return an array, but returns {return_type}'
+            )
+
+        shape = return_type.shape
+        if len(shape) != 1 or shape[0] == None:
+            raise excs.Error(
+                f'The function `{embed_fn.name}` is not a valid embedding: '
+                f'it must return a 1-dimensional array of a specific length, but returns {return_type}'
+            )
 
     def as_dict(self) -> dict:
         return {
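
A note on why validation calls `call_return_type` with a dummy argument instead of reading the declared signature: an embedding UDF like `clip` declares `pxt.Array[(None,), pxt.Float]`, and only its `conditional_return_type` hook narrows the shape once `model_id` is bound. A hypothetical trace:

# embed_fn: `clip` resolved with model_id='openai/clip-vit-base-patch32'
return_type = embed_fn.call_return_type([], {'text': 'dummy'})
# the conditional_return_type hook consults the CLIP config, yielding
# Array[(512,), Float] instead of the declared Array[(None,), Float],
# so the 1-D fixed-length check above passes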
pixeltable/io/label_studio.py CHANGED
@@ -574,7 +574,7 @@ class LabelStudioProject(Project):
         else:
             local_annotations_column = next(k for k, v in col_mapping.items() if v == ANNOTATIONS_COLUMN)
             if local_annotations_column not in t._schema.keys():
-                t[local_annotations_column] = pxt.JsonType(nullable=True)
+                t.add_columns({local_annotations_column: pxt.JsonType(nullable=True)})
 
         resolved_col_mapping = cls.validate_columns(
             t, config.export_columns, {ANNOTATIONS_COLUMN: pxt.JsonType(nullable=True)}, col_mapping)
pixeltable/metadata/__init__.py CHANGED
@@ -10,7 +10,7 @@ import sqlalchemy.orm as orm
 from .schema import SystemInfo, SystemInfoMd
 
 # current version of the metadata; this is incremented whenever the metadata schema changes
-VERSION = 25
+VERSION = 26
 
 
 def create_system_info(engine: sql.engine.Engine) -> None: