docent-python 0.1.20a0__py3-none-any.whl → 0.1.21a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,745 @@
+ import asyncio
+ import json
+ from typing import Any, Literal, cast
+
+ import backoff
+ import tiktoken
+ from backoff.types import Details
+
+ # all errors: https://platform.openai.com/docs/guides/error-codes/api-errors#python-library-error-types
+ from openai import (
+     APIConnectionError,
+     AsyncAzureOpenAI,
+     AsyncOpenAI,
+     AuthenticationError,
+     BadRequestError,
+     NotFoundError,
+     OpenAI,
+     PermissionDeniedError,
+     RateLimitError,
+     UnprocessableEntityError,
+     omit,
+ )
+ from openai.types.chat import (
+     ChatCompletion,
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionChunk,
+     ChatCompletionContentPartTextParam,
+     ChatCompletionMessageParam,
+     ChatCompletionMessageToolCallParam,
+     ChatCompletionMessageToolCallUnion,
+     ChatCompletionSystemMessageParam,
+     ChatCompletionToolMessageParam,
+     ChatCompletionToolParam,
+     ChatCompletionUserMessageParam,
+ )
+ from openai.types.chat.chat_completion_message_tool_call_param import (
+     Function as OpenAIFunctionParam,
+ )
+ from openai.types.shared_params.function_definition import FunctionDefinition
+
+ from docent._llm_util.data_models.exceptions import (
+     CompletionTooLongException,
+     ContextWindowException,
+     NoResponseException,
+     RateLimitException,
+ )
+ from docent._llm_util.data_models.llm_output import (
+     AsyncEmbeddingStreamingCallback,
+     AsyncSingleLLMOutputStreamingCallback,
+     FinishReasonType,
+     LLMCompletion,
+     LLMCompletionPartial,
+     LLMOutput,
+     LLMOutputPartial,
+     ToolCallPartial,
+     UsageMetrics,
+     finalize_llm_output_partial,
+ )
+ from docent._llm_util.providers.common import async_timeout_ctx
+ from docent._log_util import get_logger
+ from docent.data_models.chat import ChatMessage, Content, ToolCall, ToolInfo
+
+ logger = get_logger(__name__)
+ DEFAULT_TIKTOKEN_ENCODING = "cl100k_base"
+ MAX_EMBEDDING_TOKENS = 8000
+
+
+ def _print_backoff_message(e: Details):
+     logger.warning(
+         f"OpenAI backing off for {e['wait']:.2f}s due to {e['exception'].__class__.__name__}"  # type: ignore
+     )
+
+
+ def _is_retryable_error(e: BaseException) -> bool:
+     if (
+         isinstance(e, BadRequestError)
+         or isinstance(e, ContextWindowException)
+         or isinstance(e, AuthenticationError)
+         or isinstance(e, PermissionDeniedError)
+         or isinstance(e, NotFoundError)
+         or isinstance(e, UnprocessableEntityError)
+         or isinstance(e, APIConnectionError)
+     ):
+         return False
+     return True
+
+
+ def _parse_message_content(
+     content: str | list[Content],
+ ) -> str | list[ChatCompletionContentPartTextParam]:
+     if isinstance(content, str):
+         return content
+     else:
+         result: list[ChatCompletionContentPartTextParam] = []
+         for sub_content in content:
+             if sub_content.type == "text":
+                 result.append(
+                     ChatCompletionContentPartTextParam(type="text", text=sub_content.text)
+                 )
+             else:
+                 raise ValueError(f"Unsupported content type: {sub_content.type}")
+         return result
+
+
+ def parse_chat_messages(messages: list[ChatMessage]) -> list[ChatCompletionMessageParam]:
+     result: list[ChatCompletionMessageParam] = []
+
+     for message in messages:
+         if message.role == "user":
+             result.append(
+                 ChatCompletionUserMessageParam(
+                     role=message.role,
+                     content=_parse_message_content(message.content),
+                 )
+             )
+         elif message.role == "assistant":
+             tool_calls = (
+                 [
+                     ChatCompletionMessageToolCallParam(
+                         id=tool_call.id,
+                         function=OpenAIFunctionParam(
+                             name=tool_call.function,
+                             arguments=json.dumps(tool_call.arguments),
+                         ),
+                         type="function",
+                     )
+                     for tool_call in message.tool_calls
+                 ]
+                 if message.tool_calls
+                 else None
+             )
+             # Redundant code annoyingly necessary due to typechecking, but maybe I'm missing something
+             if not tool_calls:
+                 result.append(
+                     ChatCompletionAssistantMessageParam(
+                         role=message.role, content=_parse_message_content(message.content)
+                     )
+                 )
+             else:
+                 result.append(
+                     ChatCompletionAssistantMessageParam(
+                         role=message.role,
+                         content=_parse_message_content(message.content),
+                         tool_calls=tool_calls,
+                     )
+                 )
+         elif message.role == "tool":
+             result.append(
+                 ChatCompletionToolMessageParam(
+                     role=message.role,
+                     content=_parse_message_content(message.content),
+                     tool_call_id=str(message.tool_call_id),
+                 )
+             )
+         elif message.role == "system":
+             result.append(
+                 ChatCompletionSystemMessageParam(
+                     role=message.role,
+                     content=_parse_message_content(message.content),
+                 )
+             )
+
+     return result
+
+
+ def parse_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
+     """Convert ToolInfo objects to OpenAI ChatCompletionToolParam format."""
+
+     result: list[ChatCompletionToolParam] = []
+
+     for tool in tools:
+         result.append(
+             ChatCompletionToolParam(
+                 type="function",
+                 function=FunctionDefinition(
+                     name=tool.name,
+                     description=tool.description,
+                     parameters=tool.parameters.model_dump(exclude_none=True),
+                 ),
+             )
+         )
+
+     return result
+
+
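For reference, `parse_tools` emits the standard OpenAI function-tool structure. The sketch below shows only the expected output shape for one tool; the `get_weather` name, description, and JSON-schema parameters are invented for illustration and are not part of the package.

```python
# Hypothetical illustration of the structure parse_tools produces for one tool.
# The tool name, description, and schema are made up for the example.
expected_tool_param = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}
```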
+ @backoff.on_exception(
+     backoff.expo,
+     exception=(Exception,),
+     giveup=lambda e: not _is_retryable_error(e),
+     max_tries=5,
+     factor=3.0,
+     on_backoff=_print_backoff_message,
+ )
+ async def get_openai_chat_completion_streaming_async(
+     client: AsyncOpenAI,
+     streaming_callback: AsyncSingleLLMOutputStreamingCallback | None,
+     messages: list[ChatMessage],
+     model_name: str,
+     tools: list[ToolInfo] | None = None,
+     tool_choice: Literal["auto", "required"] | None = None,
+     max_new_tokens: int = 32,
+     temperature: float = 1.0,
+     reasoning_effort: Literal["low", "medium", "high"] | None = None,
+     logprobs: bool = False,
+     top_logprobs: int | None = None,
+     timeout: float = 30.0,
+ ):
+     input_messages = parse_chat_messages(messages)
+     input_tools = parse_tools(tools) if tools else omit
+
+     try:
+         async with async_timeout_ctx(timeout):
+             stream = await client.chat.completions.create(
+                 model=model_name,
+                 messages=input_messages,
+                 tools=input_tools,
+                 tool_choice=tool_choice or omit,
+                 max_completion_tokens=max_new_tokens,
+                 temperature=temperature,
+                 reasoning_effort=reasoning_effort or omit,
+                 logprobs=logprobs,
+                 top_logprobs=top_logprobs,
+                 stream_options={"include_usage": True},
+                 stream=True,
+             )
+
+             llm_output_partial = None
+             async for chunk in stream:
+                 llm_output_partial = update_llm_output(llm_output_partial, chunk)
+                 if streaming_callback:
+                     await streaming_callback(finalize_llm_output_partial(llm_output_partial))
+
+             # Fully parse the partial output
+             if llm_output_partial:
+                 return finalize_llm_output_partial(llm_output_partial)
+             else:
+                 # Streaming did not produce anything
+                 return LLMOutput(model=model_name, completions=[], errors=[NoResponseException()])
+     except (RateLimitError, BadRequestError) as e:
+         if e2 := _convert_openai_error(e):
+             raise e2 from e
+         else:
+             raise
+
+
+ def _convert_openai_error(e: Exception):
+     if isinstance(e, RateLimitError):
+         return RateLimitException(e)
+     elif isinstance(e, BadRequestError) and e.code == "context_length_exceeded":
+         return ContextWindowException()
+     return None
+
+
+ def update_llm_output(llm_output_partial: LLMOutputPartial | None, chunk: ChatCompletionChunk):
+     # Collect existing outputs
+     if llm_output_partial is not None:
+         cur_texts: list[str | None] = [c.text for c in llm_output_partial.completions]
+         cur_finish_reasons: list[FinishReasonType | None] = [
+             c.finish_reason for c in llm_output_partial.completions
+         ]
+         cur_tool_calls_all: list[list[ToolCallPartial | None] | None] = [
+             cast(list[ToolCallPartial | None], c.tool_calls) for c in llm_output_partial.completions
+         ]
+     else:
+         cur_texts, cur_finish_reasons, cur_tool_calls_all = [], [], []
+
+     # Define functions for getting and setting values of the current state
+     def _get_text(i: int):
+         if i >= len(cur_texts):
+             return None
+         else:
+             return cur_texts[i]
+
+     def _set_text(i: int, text: str):
+         if i >= len(cur_texts):
+             cur_texts.extend([None] * (i - len(cur_texts) + 1))
+         cur_texts[i] = text
+
+     def _get_finish_reason(i: int):
+         if i >= len(cur_finish_reasons) or cur_finish_reasons[i] is None:
+             return None
+         else:
+             return cur_finish_reasons[i]
+
+     def _set_finish_reason(i: int, finish_reason: FinishReasonType | None):
+         if i >= len(cur_finish_reasons):
+             cur_finish_reasons.extend([None] * (i - len(cur_finish_reasons) + 1))
+         cur_finish_reasons[i] = finish_reason
+
+     def _get_tool_calls(i: int):
+         if i >= len(cur_tool_calls_all):
+             return None
+         else:
+             return cur_tool_calls_all[i]
+
+     def _get_tool_call(i: int, j: int):
+         if i >= len(cur_tool_calls_all):
+             return None
+         else:
+             cur_tool_calls = cur_tool_calls_all[i]
+             if cur_tool_calls is None or j >= len(cur_tool_calls):
+                 return None
+             else:
+                 return cur_tool_calls[j]
+
+     def _set_tool_call(i: int, j: int, tool_call: ToolCallPartial):
+         if i >= len(cur_tool_calls_all):
+             cur_tool_calls_all.extend([None] * (i - len(cur_tool_calls_all) + 1))
+
+         # Add ToolCall to current choice index
+         cur_tool_calls = cur_tool_calls_all[i] or []
+         if j >= len(cur_tool_calls):
+             cur_tool_calls.extend([None] * (j - len(cur_tool_calls) + 1))
+         cur_tool_calls[j] = tool_call
+
+         # Re-update the global array
+         cur_tool_calls_all[i] = cur_tool_calls
+
+     # Update existing completions based on this chunk
+     for choice in chunk.choices:
+         i, delta = choice.index, choice.delta
+
+         # Resolve text and finish reason
+         _set_text(i, (_get_text(i) or "") + (delta.content or ""))
+         _set_finish_reason(i, choice.finish_reason or _get_finish_reason(i))
+
+         # Tool call resolution is more complicated
+         for tc_delta in delta.tool_calls or []:
+             tc_idx = tc_delta.index
+             tc_function = tc_delta.function.name if tc_delta.function else None
+             tc_arguments = tc_delta.function.arguments if tc_delta.function else None
+
+             old_tool_call = _get_tool_call(i, tc_idx)
+
+             if old_tool_call:
+                 tool_call_partial = ToolCallPartial(
+                     id=old_tool_call.id or tc_delta.id,
+                     function=(old_tool_call.function or "") + (tc_function or ""),
+                     arguments_raw=(old_tool_call.arguments_raw or "") + (tc_arguments or ""),
+                     type="function",
+                 )
+             else:
+                 tool_call_partial = ToolCallPartial(
+                     id=tc_delta.id,
+                     function=tc_function or "",
+                     arguments_raw=tc_arguments or "",
+                     type="function",
+                 )
+
+             _set_tool_call(i, tc_idx, tool_call_partial)
+
+     if chunk.usage is not None:
+         usage = UsageMetrics(
+             input=chunk.usage.prompt_tokens,
+             output=chunk.usage.completion_tokens,
+         )
+     else:
+         usage = UsageMetrics()
+
+     completions: list[LLMCompletionPartial] = []
+     # TODO: assert all lengths are the same
+     for i in range(len(cur_texts)):
+         completions.append(
+             LLMCompletionPartial(
+                 text=_get_text(i),
+                 tool_calls=_get_tool_calls(i),
+                 finish_reason=_get_finish_reason(i),
+             )
+         )
+
+     return LLMOutputPartial(
+         completions=completions,  # type: ignore[arg-type]
+         model=chunk.model,
+         usage=usage,
+     )
+
+
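To make the accumulation in `update_llm_output` concrete: OpenAI streams tool-call arguments as string fragments keyed by choice index and tool-call index, and the accumulator concatenates them across chunks. Below is a self-contained sketch using plain dicts that merely stand in for the real chunk objects (the field names and values are invented for illustration).

```python
# Illustrative only: fragments of one tool call arriving across two stream chunks.
deltas = [
    {"index": 0, "id": "call_1", "name": "get_weather", "arguments": '{"ci'},
    {"index": 0, "id": None, "name": None, "arguments": 'ty": "Paris"}'},
]

acc: dict[int, dict[str, str]] = {}
for d in deltas:
    slot = acc.setdefault(d["index"], {"id": "", "name": "", "arguments": ""})
    slot["id"] = slot["id"] or (d["id"] or "")   # the id arrives once, on the first fragment
    slot["name"] += d["name"] or ""              # later fragments carry None for the name
    slot["arguments"] += d["arguments"] or ""    # argument JSON is concatenated piecewise

print(acc[0]["arguments"])  # '{"city": "Paris"}'
```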
+ @backoff.on_exception(
+     backoff.expo,
+     exception=(Exception,),
+     giveup=lambda e: not _is_retryable_error(e),
+     max_tries=5,
+     factor=3.0,
+     on_backoff=_print_backoff_message,
+ )
+ async def get_openai_chat_completion_async(
+     client: AsyncOpenAI,
+     messages: list[ChatMessage],
+     model_name: str,
+     tools: list[ToolInfo] | None = None,
+     tool_choice: Literal["auto", "none", "required"] | None = None,
+     max_new_tokens: int = 32,
+     temperature: float = 1.0,
+     reasoning_effort: Literal["low", "medium", "high"] | None = None,
+     logprobs: bool = False,
+     top_logprobs: int | None = None,
+     timeout: float = 5.0,
+ ) -> LLMOutput:
+     input_messages = parse_chat_messages(messages)
+     input_tools = parse_tools(tools) if tools else omit
+
+     try:
+         async with async_timeout_ctx(timeout):  # type: ignore
+             raw_output = await client.chat.completions.create(
+                 model=model_name,
+                 messages=input_messages,
+                 tools=input_tools,
+                 tool_choice=tool_choice or omit,
+                 max_completion_tokens=max_new_tokens,
+                 temperature=temperature,
+                 reasoning_effort=reasoning_effort or omit,
+                 logprobs=logprobs,
+                 top_logprobs=top_logprobs,
+             )
+
+             # If the completion is empty and was truncated (likely due to too much reasoning), raise an exception
+             output = parse_openai_completion(raw_output, model_name)
+             if output.first and output.first.finish_reason == "length" and output.first.no_text:
+                 raise CompletionTooLongException(
+                     "Completion empty due to truncation. Consider increasing max_new_tokens."
+                 )
+             for c in output.completions:
+                 if c.finish_reason == "length":
+                     logger.warning(
+                         "Completion truncated due to length; consider increasing max_new_tokens."
+                     )
+
+             return output
+     except (RateLimitError, BadRequestError) as e:
+         if e2 := _convert_openai_error(e):
+             raise e2 from e
+         else:
+             raise
+
+
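A minimal usage sketch for the non-streaming helper, assuming a `messages: list[ChatMessage]` has already been built with docent's chat data models. The `.first`/`.text` accessors mirror how the module itself reads `LLMOutput`; the model name and token budget below are arbitrary examples, not package defaults.

```python
import asyncio

from openai import AsyncOpenAI

async def demo(messages: list[ChatMessage]) -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    output = await get_openai_chat_completion_async(
        client,
        messages,
        model_name="gpt-4o-mini",  # example model name
        max_new_tokens=256,        # the default of 32 truncates most answers
        timeout=30.0,
    )
    if output.first is not None:
        print(output.first.text)

# asyncio.run(demo(messages))  # given a previously built list[ChatMessage]
```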
+ def get_openai_client_async(api_key: str | None = None) -> AsyncOpenAI:
+     return AsyncOpenAI(api_key=api_key) if api_key else AsyncOpenAI()
+
+
+ def get_azure_openai_client_async(api_key: str | None = None) -> AsyncAzureOpenAI:
+     return AsyncAzureOpenAI(api_key=api_key) if api_key else AsyncAzureOpenAI()
+
+
+ def chunk_and_tokenize(
+     text: list[str],
+     window_size: int = 8191,
+     overlap: int = 128,
+ ) -> tuple[list[list[int]], list[int]]:
+     """Encode a list of text into a list of token ids."""
+
+     def _chunk_tokens(tokens: list[int], window_size: int, overlap: int) -> list[list[int]]:
+         """Compute list chunks with overlap."""
+         if overlap >= window_size:
+             raise ValueError("overlap must be smaller than window_size")
+
+         stride = window_size - overlap
+         chunks: list[list[int]] = []
+         for i in range(0, len(tokens), stride):
+             chunks.append(tokens[i : i + window_size])
+         return chunks
+
+     encoding = tiktoken.get_encoding(DEFAULT_TIKTOKEN_ENCODING)
+
+     all_chunks: list[list[int]] = []
+     chunk_to_doc: list[int] = []
+
+     for i, item in enumerate(text):
+         tokens = encoding.encode(item)
+         if len(tokens) <= window_size:
+             chunks = [tokens]
+         else:
+             chunks = _chunk_tokens(tokens, window_size, overlap)
+
+         all_chunks.extend(chunks)
+         chunk_to_doc.extend([i] * len(chunks))
+
+     return all_chunks, chunk_to_doc
+
+
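The stride arithmetic in `_chunk_tokens` is easiest to see with small numbers; in the sketch below (values invented for illustration) each consecutive pair of chunks shares exactly `overlap` tokens.

```python
# With window_size=4 and overlap=1, the stride is 3, so chunk starts are 0, 3, 6, 9.
tokens = list(range(10))           # pretend token ids 0..9
window_size, overlap = 4, 1
stride = window_size - overlap     # 3
chunks = [tokens[i : i + window_size] for i in range(0, len(tokens), stride)]
print(chunks)  # [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9]]
```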
+ @backoff.on_exception(
+     backoff.expo,
+     exception=(Exception,),
+     giveup=lambda e: not _is_retryable_error(e),
+     max_tries=5,
+     factor=3.0,
+     on_backoff=_print_backoff_message,
+ )
+ async def _get_openai_embeddings_async_one_batch(
+     client: AsyncOpenAI, batch: list[str] | list[list[int]], model_name: str, dimensions: int | None
+ ):
+     try:
+         response = await client.embeddings.create(
+             model=model_name,
+             input=batch,
+             dimensions=dimensions if dimensions is not None else omit,
+         )
+         return [data.embedding for data in response.data]
+     except RateLimitError as e:
+         raise RateLimitException(e) from e
+
+
+ async def get_chunked_openai_embeddings_async(
+     texts: list[str],
+     model_name: str = "text-embedding-3-small",
+     dimensions: int | None = 512,
+     window_size: int = MAX_EMBEDDING_TOKENS,
+     overlap: int = 128,
+     max_concurrency: int = 100,
+     callback: AsyncEmbeddingStreamingCallback | None = None,
+ ) -> tuple[list[list[float]], list[int]]:
+     """
+     Asynchronously get embeddings for a list of texts using OpenAI's embedding model.
+     Texts are tokenized with tiktoken and split into overlapping chunks of at most
+     `window_size` tokens, each of which is embedded separately; the second return value
+     maps each chunk back to the index of its source text.
+     Concurrency is limited using a semaphore.
+     """
+
+     if model_name != "text-embedding-3-large" and model_name != "text-embedding-3-small":
+         assert dimensions is None, f"{model_name} does not have a variable dimension size"
+
+     all_chunks, chunk_to_doc = chunk_and_tokenize(texts, window_size=window_size, overlap=overlap)
+
+     # Create batches of 25 chunks. The embedding endpoint has a per-request token limit.
+     batches = [all_chunks[i : i + 25] for i in range(0, len(all_chunks), 25)]
+
+     client = get_openai_client_async()
+     semaphore = asyncio.Semaphore(max_concurrency)
+
+     batches_done = 0
+     batches_total = len(batches)
+
+     async def limited_task(batch: list[list[int]]):
+         nonlocal batches_done
+
+         async with semaphore:
+             embeddings = await _get_openai_embeddings_async_one_batch(
+                 client, batch, model_name, dimensions
+             )
+             batches_done += 1
+
+             if callback:
+                 progress = int(batches_done / batches_total * 100)
+                 await callback(progress)
+
+             return embeddings
+
+     # Run tasks concurrently
+     tasks = [limited_task(batch) for batch in batches]
+     results = await asyncio.gather(*tasks)
+
+     # Flatten the results
+     embeddings = [embedding for batch_result in results for embedding in batch_result]
+
+     return embeddings, chunk_to_doc
+
+
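One way to consume the `(embeddings, chunk_to_doc)` pair is to mean-pool the chunk vectors back into a single vector per input text. This is a common pooling choice, not something the package prescribes; the sketch below assumes the default model and dimensions.

```python
import asyncio
from collections import defaultdict

async def embed_documents(texts: list[str]) -> list[list[float]]:
    chunk_embeddings, chunk_to_doc = await get_chunked_openai_embeddings_async(texts)

    # Group chunk vectors by the index of their source document.
    grouped: dict[int, list[list[float]]] = defaultdict(list)
    for emb, doc_idx in zip(chunk_embeddings, chunk_to_doc):
        grouped[doc_idx].append(emb)

    # Mean-pool each document's chunk vectors component-wise.
    pooled: list[list[float]] = []
    for doc_idx in range(len(texts)):
        vecs = grouped[doc_idx]
        pooled.append([sum(vals) / len(vals) for vals in zip(*vecs)])
    return pooled

# asyncio.run(embed_documents(["some long document ...", "another document ..."]))
```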
+ async def get_openai_embeddings_async(
+     client: AsyncOpenAI,
+     texts: list[str],
+     model_name: str = "text-embedding-3-large",
+     dimensions: int | None = 3072,
+     max_concurrency: int = 100,
+ ) -> list[list[float] | None]:
+     """
+     Asynchronously get embeddings for a list of texts using OpenAI's embedding model.
+     This function uses tiktoken for tokenization, truncates at 8000 tokens, and prints a warning if truncation occurs.
+     Concurrency is limited using a semaphore.
+     """
+
+     if model_name != "text-embedding-3-large":
+         assert dimensions is None, f"{model_name} does not have a variable dimension size"
+
+     # Tokenize and truncate texts
+     tokenizer = tiktoken.get_encoding(DEFAULT_TIKTOKEN_ENCODING)
+     truncated_texts: list[str] = []
+     for i, text in enumerate(texts):
+         tokens = tokenizer.encode(text)
+         if len(tokens) > MAX_EMBEDDING_TOKENS:
+             print(
+                 f"Warning: Text at index {i} has been truncated from {len(tokens)} to {MAX_EMBEDDING_TOKENS} tokens."
+             )
+             tokens = tokens[:MAX_EMBEDDING_TOKENS]
+         truncated_texts.append(tokenizer.decode(tokens))
+
+     semaphore = asyncio.Semaphore(max_concurrency)
+
+     async def limited_task(texts_batch: list[str]):
+         async with semaphore:
+             try:
+                 return await _get_openai_embeddings_async_one_batch(
+                     client, texts_batch, model_name, dimensions
+                 )
+             except Exception as e:
+                 print(f"Error in fetch_embeddings: {e}. Returning None.")
+                 return [None] * len(texts_batch)
+
+     # Create batches of 1000 texts (OpenAI's current limit per request)
+     batches = [truncated_texts[i : i + 1000] for i in range(0, len(truncated_texts), 1000)]
+
+     # Run tasks concurrently
+     tasks = [limited_task(batch) for batch in batches]
+     results = await asyncio.gather(*tasks)
+
+     # Flatten the results
+     embeddings = [embedding for batch_result in results for embedding in batch_result]
+
+     return embeddings
+
+
+ def get_openai_embeddings_sync(
+     client: OpenAI,
+     texts: list[str],
+     model_name: str = "text-embedding-3-large",
+     dimensions: int | None = 1536,
+ ) -> list[list[float] | None]:
+     """
+     Synchronously get embeddings for a list of texts using OpenAI's embedding model.
+     This function uses tiktoken for tokenization and truncates at 8000 tokens.
+     """
+     # Tokenize and truncate texts
+     tokenizer = tiktoken.get_encoding(DEFAULT_TIKTOKEN_ENCODING)
+     truncated_texts: list[str] = []
+     for i, text in enumerate(texts):
+         tokens = tokenizer.encode(text)
+         if len(tokens) > MAX_EMBEDDING_TOKENS:
+             print(
+                 f"Warning: Text at index {i} has been truncated from {len(tokens)} to {MAX_EMBEDDING_TOKENS} tokens."
+             )
+             tokens = tokens[:MAX_EMBEDDING_TOKENS]
+         truncated_texts.append(tokenizer.decode(tokens))
+
+     # Process in batches of 1000
+     embeddings: list[list[float] | None] = []
+     for i in range(0, len(truncated_texts), 1000):
+         batch = truncated_texts[i : i + 1000]
+         try:
+             response = client.embeddings.create(
+                 model=model_name,
+                 input=batch,
+                 dimensions=dimensions if dimensions is not None else omit,
+             )
+             batch_embeddings = [data.embedding for data in response.data]
+             embeddings.extend(batch_embeddings)
+         except Exception as e:
+             print(f"Error in get_openai_embeddings_sync: {e}")
+             embeddings.extend([None] * len(batch))
+
+     return embeddings
+
+
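A short usage sketch for the synchronous variant; the input strings are placeholders, and `OpenAI()` picks up `OPENAI_API_KEY` from the environment. Failed batches come back as `None` entries, so callers should check each vector.

```python
from openai import OpenAI

client = OpenAI()
vectors = get_openai_embeddings_sync(
    client,
    ["first example passage", "second example passage"],
    model_name="text-embedding-3-large",
    dimensions=1536,
)
for i, vec in enumerate(vectors):
    if vec is None:
        print(f"embedding {i} failed")
    else:
        print(f"embedding {i} has {len(vec)} dimensions")
```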
+ def _parse_openai_tool_call(tc: ChatCompletionMessageToolCallUnion) -> ToolCall:
+     # Only handle function tool calls, skip custom tool calls
+     if tc.type != "function":
+         return ToolCall(
+             id=tc.id,
+             function="unknown",
+             arguments={},
+             parse_error=f"Unsupported tool call type: {tc.type}",
+             type=None,
+         )
+
+     # Attempt to parse the tool call arguments as JSON
+     arguments: dict[str, Any] = {}
+     try:
+         arguments = json.loads(tc.function.arguments)
+         parse_error = None
+     # If the tool call arguments are not valid JSON, return an empty dict with the error
+     except Exception as e:
+         arguments = {"__parse_error_raw_args": tc.function.arguments}
+         parse_error = f"Couldn't parse tool call arguments as JSON: {e}. Original input: {tc.function.arguments}"
+
+     return ToolCall(
+         id=tc.id,
+         function=tc.function.name,
+         arguments=arguments,
+         parse_error=parse_error,
+         type=tc.type,
+     )
+
+
+ def parse_openai_completion(response: ChatCompletion | None, model: str) -> LLMOutput:
+     if response is None:
+         return LLMOutput(
+             model=model,
+             completions=[],
+             errors=[NoResponseException()],
+         )
+
+     # Extract token usage if available
+     if response.usage:
+         usage = UsageMetrics(
+             input=response.usage.prompt_tokens,
+             output=response.usage.completion_tokens,
+         )
+     else:
+         logger.warning("OpenAI response did not include usage metrics")
+         usage = UsageMetrics()
+
+     return LLMOutput(
+         model=response.model,
+         completions=[
+             LLMCompletion(
+                 text=choice.message.content,
+                 finish_reason=choice.finish_reason,
+                 tool_calls=(
+                     [_parse_openai_tool_call(tc) for tc in tcs]
+                     if (tcs := choice.message.tool_calls)
+                     else None
+                 ),
+                 top_logprobs=(
+                     [pos.top_logprobs for pos in choice.logprobs.content]
+                     if choice.logprobs and choice.logprobs.content is not None
+                     else None
+                 ),
+             )
+             for choice in response.choices
+         ],
+         usage=usage,
+     )
+
+
+ async def is_openai_api_key_valid(api_key: str) -> bool:
+     """
+     Test whether an OpenAI API key is valid or invalid.
+
+     Args:
+         api_key: The OpenAI API key to test.
+
+     Returns:
+         bool: True if the API key is valid, False otherwise.
+     """
+     client = AsyncOpenAI(api_key=api_key)
+
+     try:
+         # Attempt to make a simple API call with minimal tokens/cost
+         await client.chat.completions.create(
+             model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}], max_tokens=1
+         )
+         return True
+     except AuthenticationError:
+         # API key is invalid
+         return False
+     except Exception:
+         # Any other error means the key might be valid but there's another issue
+         # For testing key validity specifically, we'll return False only for auth errors
+         return True
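
Finally, a minimal usage sketch for the key-check helper; the key string is a placeholder. Note that, per the comments above, only an `AuthenticationError` marks the key as invalid, so any other failure (network issues, rate limits) still reports the key as accepted.

```python
import asyncio

async def main() -> None:
    ok = await is_openai_api_key_valid("sk-...")  # placeholder key
    print("key accepted" if ok else "key rejected (authentication error)")

asyncio.run(main())
```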