docent-python 0.1.20a0__py3-none-any.whl → 0.1.22a0__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Potentially problematic release: this version of docent-python has been flagged; see the advisory details on the package's registry page.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +320 -0
- docent/_llm_util/data_models/simple_svc.py +79 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/model_registry.py +126 -0
- docent/_llm_util/prod_llms.py +454 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/transcript.py +2 -0
- docent/judges/__init__.py +21 -0
- docent/judges/impl.py +232 -0
- docent/judges/types.py +240 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +84 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +95 -0
- docent/judges/util/voting.py +114 -0
- docent/trace.py +19 -1
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/METADATA +7 -1
- docent_python-0.1.22a0.dist-info/RECORD +58 -0
- docent_python-0.1.20a0.dist-info/RECORD +0 -34
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/licenses/LICENSE.md +0 -0
docent/_llm_util/providers/openai.py
@@ -0,0 +1,745 @@
import asyncio
import json
from typing import Any, Literal, cast

import backoff
import tiktoken
from backoff.types import Details

# all errors: https://platform.openai.com/docs/guides/error-codes/api-errors#python-library-error-types
from openai import (
    APIConnectionError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    AuthenticationError,
    BadRequestError,
    NotFoundError,
    OpenAI,
    PermissionDeniedError,
    RateLimitError,
    UnprocessableEntityError,
    omit,
)
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionAssistantMessageParam,
    ChatCompletionChunk,
    ChatCompletionContentPartTextParam,
    ChatCompletionMessageParam,
    ChatCompletionMessageToolCallParam,
    ChatCompletionMessageToolCallUnion,
    ChatCompletionSystemMessageParam,
    ChatCompletionToolMessageParam,
    ChatCompletionToolParam,
    ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
    Function as OpenAIFunctionParam,
)
from openai.types.shared_params.function_definition import FunctionDefinition

from docent._llm_util.data_models.exceptions import (
    CompletionTooLongException,
    ContextWindowException,
    NoResponseException,
    RateLimitException,
)
from docent._llm_util.data_models.llm_output import (
    AsyncEmbeddingStreamingCallback,
    AsyncSingleLLMOutputStreamingCallback,
    FinishReasonType,
    LLMCompletion,
    LLMCompletionPartial,
    LLMOutput,
    LLMOutputPartial,
    ToolCallPartial,
    UsageMetrics,
    finalize_llm_output_partial,
)
from docent._llm_util.providers.common import async_timeout_ctx
from docent._log_util import get_logger
from docent.data_models.chat import ChatMessage, Content, ToolCall, ToolInfo

logger = get_logger(__name__)
DEFAULT_TIKTOKEN_ENCODING = "cl100k_base"
MAX_EMBEDDING_TOKENS = 8000


def _print_backoff_message(e: Details):
    logger.warning(
        f"OpenAI backing off for {e['wait']:.2f}s due to {e['exception'].__class__.__name__}"  # type: ignore
    )


def _is_retryable_error(e: BaseException) -> bool:
    if (
        isinstance(e, BadRequestError)
        or isinstance(e, ContextWindowException)
        or isinstance(e, AuthenticationError)
        or isinstance(e, PermissionDeniedError)
        or isinstance(e, NotFoundError)
        or isinstance(e, UnprocessableEntityError)
        or isinstance(e, APIConnectionError)
    ):
        return False
    return True


def _parse_message_content(
    content: str | list[Content],
) -> str | list[ChatCompletionContentPartTextParam]:
    if isinstance(content, str):
        return content
    else:
        result: list[ChatCompletionContentPartTextParam] = []
        for sub_content in content:
            if sub_content.type == "text":
                result.append(
                    ChatCompletionContentPartTextParam(type="text", text=sub_content.text)
                )
            else:
                raise ValueError(f"Unsupported content type: {sub_content.type}")
        return result


def parse_chat_messages(messages: list[ChatMessage]) -> list[ChatCompletionMessageParam]:
    result: list[ChatCompletionMessageParam] = []

    for message in messages:
        if message.role == "user":
            result.append(
                ChatCompletionUserMessageParam(
                    role=message.role,
                    content=_parse_message_content(message.content),
                )
            )
        elif message.role == "assistant":
            tool_calls = (
                [
                    ChatCompletionMessageToolCallParam(
                        id=tool_call.id,
                        function=OpenAIFunctionParam(
                            name=tool_call.function,
                            arguments=json.dumps(tool_call.arguments),
                        ),
                        type="function",
                    )
                    for tool_call in message.tool_calls
                ]
                if message.tool_calls
                else None
            )
            # Redundant code annoyingly necessary due to typechecking, but maybe I'm missing something
            if not tool_calls:
                result.append(
                    ChatCompletionAssistantMessageParam(
                        role=message.role, content=_parse_message_content(message.content)
                    )
                )
            else:
                result.append(
                    ChatCompletionAssistantMessageParam(
                        role=message.role,
                        content=_parse_message_content(message.content),
                        tool_calls=tool_calls,
                    )
                )
        elif message.role == "tool":
            result.append(
                ChatCompletionToolMessageParam(
                    role=message.role,
                    content=_parse_message_content(message.content),
                    tool_call_id=str(message.tool_call_id),
                )
            )
        elif message.role == "system":
            result.append(
                ChatCompletionSystemMessageParam(
                    role=message.role,
                    content=_parse_message_content(message.content),
                )
            )

    return result


def parse_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
    """Convert ToolInfo objects to OpenAI ChatCompletionToolParam format."""

    result: list[ChatCompletionToolParam] = []

    for tool in tools:
        result.append(
            ChatCompletionToolParam(
                type="function",
                function=FunctionDefinition(
                    name=tool.name,
                    description=tool.description,
                    parameters=tool.parameters.model_dump(exclude_none=True),
                ),
            )
        )

    return result
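# Editor's note (illustrative, not part of the packaged file): parse_tools maps each ToolInfo
# onto OpenAI's function-tool schema, so a tool named "get_weather" with a JSON-schema
# `parameters` model becomes {"type": "function", "function": {"name": "get_weather",
# "description": ..., "parameters": {...}}}, which is the shape chat.completions.create expects.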


@backoff.on_exception(
    backoff.expo,
    exception=(Exception,),
    giveup=lambda e: not _is_retryable_error(e),
    max_tries=5,
    factor=3.0,
    on_backoff=_print_backoff_message,
)
async def get_openai_chat_completion_streaming_async(
    client: AsyncOpenAI,
    streaming_callback: AsyncSingleLLMOutputStreamingCallback | None,
    messages: list[ChatMessage],
    model_name: str,
    tools: list[ToolInfo] | None = None,
    tool_choice: Literal["auto", "required"] | None = None,
    max_new_tokens: int = 32,
    temperature: float = 1.0,
    reasoning_effort: Literal["low", "medium", "high"] | None = None,
    logprobs: bool = False,
    top_logprobs: int | None = None,
    timeout: float = 30.0,
):
    input_messages = parse_chat_messages(messages)
    input_tools = parse_tools(tools) if tools else omit

    try:
        async with async_timeout_ctx(timeout):
            stream = await client.chat.completions.create(
                model=model_name,
                messages=input_messages,
                tools=input_tools,
                tool_choice=tool_choice or omit,
                max_completion_tokens=max_new_tokens,
                temperature=temperature,
                reasoning_effort=reasoning_effort or omit,
                logprobs=logprobs,
                top_logprobs=top_logprobs,
                stream_options={"include_usage": True},
                stream=True,
            )

            llm_output_partial = None
            async for chunk in stream:
                llm_output_partial = update_llm_output(llm_output_partial, chunk)
                if streaming_callback:
                    await streaming_callback(finalize_llm_output_partial(llm_output_partial))

        # Fully parse the partial output
        if llm_output_partial:
            return finalize_llm_output_partial(llm_output_partial)
        else:
            # Streaming did not produce anything
            return LLMOutput(model=model_name, completions=[], errors=[NoResponseException()])
    except (RateLimitError, BadRequestError) as e:
        if e2 := _convert_openai_error(e):
            raise e2 from e
        else:
            raise


def _convert_openai_error(e: Exception):
    if isinstance(e, RateLimitError):
        return RateLimitException(e)
    elif isinstance(e, BadRequestError) and e.code == "context_length_exceeded":
        return ContextWindowException()
    return None


def update_llm_output(llm_output_partial: LLMOutputPartial | None, chunk: ChatCompletionChunk):
    # Collect existing outputs
    if llm_output_partial is not None:
        cur_texts: list[str | None] = [c.text for c in llm_output_partial.completions]
        cur_finish_reasons: list[FinishReasonType | None] = [
            c.finish_reason for c in llm_output_partial.completions
        ]
        cur_tool_calls_all: list[list[ToolCallPartial | None] | None] = [
            cast(list[ToolCallPartial | None], c.tool_calls) for c in llm_output_partial.completions
        ]
    else:
        cur_texts, cur_finish_reasons, cur_tool_calls_all = [], [], []

    # Define functions for getting and setting values of the current state
    def _get_text(i: int):
        if i >= len(cur_texts):
            return None
        else:
            return cur_texts[i]

    def _set_text(i: int, text: str):
        if i >= len(cur_texts):
            cur_texts.extend([None] * (i - len(cur_texts) + 1))
        cur_texts[i] = text

    def _get_finish_reason(i: int):
        if i >= len(cur_finish_reasons) or cur_finish_reasons[i] is None:
            return None
        else:
            return cur_finish_reasons[i]

    def _set_finish_reason(i: int, finish_reason: FinishReasonType | None):
        if i >= len(cur_finish_reasons):
            cur_finish_reasons.extend([None] * (i - len(cur_finish_reasons) + 1))
        cur_finish_reasons[i] = finish_reason

    def _get_tool_calls(i: int):
        if i >= len(cur_tool_calls_all):
            return None
        else:
            return cur_tool_calls_all[i]

    def _get_tool_call(i: int, j: int):
        if i >= len(cur_tool_calls_all):
            return None
        else:
            cur_tool_calls = cur_tool_calls_all[i]
            if cur_tool_calls is None or j >= len(cur_tool_calls):
                return None
            else:
                return cur_tool_calls[j]

    def _set_tool_call(i: int, j: int, tool_call: ToolCallPartial):
        if i >= len(cur_tool_calls_all):
            cur_tool_calls_all.extend([None] * (i - len(cur_tool_calls_all) + 1))

        # Add ToolCall to current choice index
        cur_tool_calls = cur_tool_calls_all[i] or []
        if j >= len(cur_tool_calls):
            cur_tool_calls.extend([None] * (j - len(cur_tool_calls) + 1))
        cur_tool_calls[j] = tool_call

        # Re-update the global array
        cur_tool_calls_all[i] = cur_tool_calls

    # Update existing completions based on this chunk
    for choice in chunk.choices:
        i, delta = choice.index, choice.delta

        # Resolve text and finish reason
        _set_text(i, (_get_text(i) or "") + (delta.content or ""))
        _set_finish_reason(i, choice.finish_reason or _get_finish_reason(i))

        # Tool call resolution is more complicated
        for tc_delta in delta.tool_calls or []:
            tc_idx = tc_delta.index
            tc_function = tc_delta.function.name if tc_delta.function else None
            tc_arguments = tc_delta.function.arguments if tc_delta.function else None

            old_tool_call = _get_tool_call(i, tc_idx)

            if old_tool_call:
                tool_call_partial = ToolCallPartial(
                    id=old_tool_call.id or tc_delta.id,
                    function=(old_tool_call.function or "") + (tc_function or ""),
                    arguments_raw=(old_tool_call.arguments_raw or "") + (tc_arguments or ""),
                    type="function",
                )
            else:
                tool_call_partial = ToolCallPartial(
                    id=tc_delta.id,
                    function=tc_function or "",
                    arguments_raw=tc_arguments or "",
                    type="function",
                )

            _set_tool_call(i, tc_idx, tool_call_partial)

    if chunk.usage is not None:
        usage = UsageMetrics(
            input=chunk.usage.prompt_tokens,
            output=chunk.usage.completion_tokens,
        )
    else:
        usage = UsageMetrics()

    completions: list[LLMCompletionPartial] = []
    # TODO: assert all lists have the same length
    for i in range(len(cur_texts)):
        completions.append(
            LLMCompletionPartial(
                text=_get_text(i),
                tool_calls=_get_tool_calls(i),
                finish_reason=_get_finish_reason(i),
            )
        )

    return LLMOutputPartial(
        completions=completions,  # type: ignore[arg-type]
        model=chunk.model,
        usage=usage,
    )
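# Editor's note (illustrative, not part of the packaged file): update_llm_output folds a single
# streamed ChatCompletionChunk into the running LLMOutputPartial. If choice 0 has accumulated
# text "Hel" and the next chunk's delta.content is "lo", the stored text becomes "Hello";
# tool-call deltas are appended the same way (name and raw-argument fragments concatenated per
# tool-call index), and the caller finalizes the partial with finalize_llm_output_partial.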


@backoff.on_exception(
    backoff.expo,
    exception=(Exception,),
    giveup=lambda e: not _is_retryable_error(e),
    max_tries=5,
    factor=3.0,
    on_backoff=_print_backoff_message,
)
async def get_openai_chat_completion_async(
    client: AsyncOpenAI,
    messages: list[ChatMessage],
    model_name: str,
    tools: list[ToolInfo] | None = None,
    tool_choice: Literal["auto", "none", "required"] | None = None,
    max_new_tokens: int = 32,
    temperature: float = 1.0,
    reasoning_effort: Literal["low", "medium", "high"] | None = None,
    logprobs: bool = False,
    top_logprobs: int | None = None,
    timeout: float = 5.0,
) -> LLMOutput:
    input_messages = parse_chat_messages(messages)
    input_tools = parse_tools(tools) if tools else omit

    try:
        async with async_timeout_ctx(timeout):  # type: ignore
            raw_output = await client.chat.completions.create(
                model=model_name,
                messages=input_messages,
                tools=input_tools,
                tool_choice=tool_choice or omit,
                max_completion_tokens=max_new_tokens,
                temperature=temperature,
                reasoning_effort=reasoning_effort or omit,
                logprobs=logprobs,
                top_logprobs=top_logprobs,
            )

        # If the completion is empty and was truncated (likely due to too much reasoning), raise an exception
        output = parse_openai_completion(raw_output, model_name)
        if output.first and output.first.finish_reason == "length" and output.first.no_text:
            raise CompletionTooLongException(
                "Completion empty due to truncation. Consider increasing max_new_tokens."
            )
        for c in output.completions:
            if c.finish_reason == "length":
                logger.warning(
                    "Completion truncated due to length; consider increasing max_new_tokens."
                )

        return output
    except (RateLimitError, BadRequestError) as e:
        if e2 := _convert_openai_error(e):
            raise e2 from e
        else:
            raise


def get_openai_client_async(api_key: str | None = None) -> AsyncOpenAI:
    return AsyncOpenAI(api_key=api_key) if api_key else AsyncOpenAI()


def get_azure_openai_client_async(api_key: str | None = None) -> AsyncAzureOpenAI:
    return AsyncAzureOpenAI(api_key=api_key) if api_key else AsyncAzureOpenAI()


def chunk_and_tokenize(
    text: list[str],
    window_size: int = 8191,
    overlap: int = 128,
) -> tuple[list[list[int]], list[int]]:
    """Encode each text into overlapping token-id chunks, plus a mapping from chunk index to source text index."""

    def _chunk_tokens(tokens: list[int], window_size: int, overlap: int) -> list[list[int]]:
        """Compute list chunks with overlap."""
        if overlap >= window_size:
            raise ValueError("overlap must be smaller than window_size")

        stride = window_size - overlap
        chunks: list[list[int]] = []
        for i in range(0, len(tokens), stride):
            chunks.append(tokens[i : i + window_size])
        return chunks

    encoding = tiktoken.get_encoding(DEFAULT_TIKTOKEN_ENCODING)

    all_chunks: list[list[int]] = []
    chunk_to_doc: list[int] = []

    for i, item in enumerate(text):
        tokens = encoding.encode(item)
        if len(tokens) <= window_size:
            chunks = [tokens]
        else:
            chunks = _chunk_tokens(tokens, window_size, overlap)

        all_chunks.extend(chunks)
        chunk_to_doc.extend([i] * len(chunks))

    return all_chunks, chunk_to_doc
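# Editor's note (illustrative, not part of the packaged file): with window_size=4 and overlap=1
# the stride is 3, so a 7-token text yields the token slices [0:4], [3:7], and [6:7], and
# chunk_to_doc records the source text's index once for each emitted chunk.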


@backoff.on_exception(
    backoff.expo,
    exception=(Exception,),
    giveup=lambda e: not _is_retryable_error(e),
    max_tries=5,
    factor=3.0,
    on_backoff=_print_backoff_message,
)
async def _get_openai_embeddings_async_one_batch(
    client: AsyncOpenAI, batch: list[str] | list[list[int]], model_name: str, dimensions: int | None
):
    try:
        response = await client.embeddings.create(
            model=model_name,
            input=batch,
            dimensions=dimensions if dimensions is not None else omit,
        )
        return [data.embedding for data in response.data]
    except RateLimitError as e:
        raise RateLimitException(e) from e


async def get_chunked_openai_embeddings_async(
    texts: list[str],
    model_name: str = "text-embedding-3-small",
    dimensions: int | None = 512,
    window_size: int = MAX_EMBEDDING_TOKENS,
    overlap: int = 128,
    max_concurrency: int = 100,
    callback: AsyncEmbeddingStreamingCallback | None = None,
) -> tuple[list[list[float]], list[int]]:
    """
    Asynchronously get embeddings for a list of texts using OpenAI's embedding model.
    Long texts are split into overlapping token windows via chunk_and_tokenize.
    Concurrency is limited using a semaphore, and the optional callback receives progress updates.
    """

    if model_name != "text-embedding-3-large" and model_name != "text-embedding-3-small":
        assert dimensions is None, f"{model_name} does not have a variable dimension size"

    all_chunks, chunk_to_doc = chunk_and_tokenize(texts, window_size=window_size, overlap=overlap)

    # Create batches of 25 chunks; the embedding endpoint has a per-request token limit.
    batches = [all_chunks[i : i + 25] for i in range(0, len(all_chunks), 25)]

    client = get_openai_client_async()
    semaphore = asyncio.Semaphore(max_concurrency)

    batches_done = 0
    batches_total = len(batches)

    async def limited_task(batch: list[list[int]]):
        nonlocal batches_done

        async with semaphore:
            embeddings = await _get_openai_embeddings_async_one_batch(
                client, batch, model_name, dimensions
            )
            batches_done += 1

            if callback:
                progress = int(batches_done / batches_total * 100)
                await callback(progress)

            return embeddings

    # Run tasks concurrently
    tasks = [limited_task(batch) for batch in batches]
    results = await asyncio.gather(*tasks)

    # Flatten the results
    embeddings = [embedding for batch_result in results for embedding in batch_result]

    return embeddings, chunk_to_doc
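# Editor's note (illustrative, not part of the packaged file): the two return values are
# parallel lists; embeddings[k] is the vector for the k-th chunk and chunk_to_doc[k] is the
# index of the input text it came from, so chunk vectors can be regrouped per document.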


async def get_openai_embeddings_async(
    client: AsyncOpenAI,
    texts: list[str],
    model_name: str = "text-embedding-3-large",
    dimensions: int | None = 3072,
    max_concurrency: int = 100,
) -> list[list[float] | None]:
    """
    Asynchronously get embeddings for a list of texts using OpenAI's embedding model.
    This function uses tiktoken for tokenization, truncates at 8000 tokens, and prints a warning if truncation occurs.
    Concurrency is limited using a semaphore.
    """

    if model_name != "text-embedding-3-large":
        assert dimensions is None, f"{model_name} does not have a variable dimension size"

    # Tokenize and truncate texts
    tokenizer = tiktoken.get_encoding(DEFAULT_TIKTOKEN_ENCODING)
    truncated_texts: list[str] = []
    for i, text in enumerate(texts):
        tokens = tokenizer.encode(text)
        if len(tokens) > MAX_EMBEDDING_TOKENS:
            print(
                f"Warning: Text at index {i} has been truncated from {len(tokens)} to {MAX_EMBEDDING_TOKENS} tokens."
            )
            tokens = tokens[:MAX_EMBEDDING_TOKENS]
        truncated_texts.append(tokenizer.decode(tokens))

    semaphore = asyncio.Semaphore(max_concurrency)

    async def limited_task(texts_batch: list[str]):
        async with semaphore:
            try:
                return await _get_openai_embeddings_async_one_batch(
                    client, texts_batch, model_name, dimensions
                )
            except Exception as e:
                print(f"Error in fetch_embeddings: {e}. Returning None.")
                return [None] * len(texts_batch)

    # Create batches of 1000 texts (OpenAI's current limit per request)
    batches = [truncated_texts[i : i + 1000] for i in range(0, len(truncated_texts), 1000)]

    # Run tasks concurrently
    tasks = [limited_task(batch) for batch in batches]
    results = await asyncio.gather(*tasks)

    # Flatten the results
    embeddings = [embedding for batch_result in results for embedding in batch_result]

    return embeddings


def get_openai_embeddings_sync(
    client: OpenAI,
    texts: list[str],
    model_name: str = "text-embedding-3-large",
    dimensions: int | None = 1536,
) -> list[list[float] | None]:
    """
    Synchronously get embeddings for a list of texts using OpenAI's embedding model.
    This function uses tiktoken for tokenization and truncates at 8000 tokens.
    """
    # Tokenize and truncate texts
    tokenizer = tiktoken.get_encoding(DEFAULT_TIKTOKEN_ENCODING)
    truncated_texts: list[str] = []
    for i, text in enumerate(texts):
        tokens = tokenizer.encode(text)
        if len(tokens) > MAX_EMBEDDING_TOKENS:
            print(
                f"Warning: Text at index {i} has been truncated from {len(tokens)} to {MAX_EMBEDDING_TOKENS} tokens."
            )
            tokens = tokens[:MAX_EMBEDDING_TOKENS]
        truncated_texts.append(tokenizer.decode(tokens))

    # Process in batches of 1000
    embeddings: list[list[float] | None] = []
    for i in range(0, len(truncated_texts), 1000):
        batch = truncated_texts[i : i + 1000]
        try:
            response = client.embeddings.create(
                model=model_name,
                input=batch,
                dimensions=dimensions if dimensions is not None else omit,
            )
            batch_embeddings = [data.embedding for data in response.data]
            embeddings.extend(batch_embeddings)
        except Exception as e:
            print(f"Error in get_openai_embeddings_sync: {e}")
            embeddings.extend([None] * len(batch))

    return embeddings


def _parse_openai_tool_call(tc: ChatCompletionMessageToolCallUnion) -> ToolCall:
    # Only handle function tool calls, skip custom tool calls
    if tc.type != "function":
        return ToolCall(
            id=tc.id,
            function="unknown",
            arguments={},
            parse_error=f"Unsupported tool call type: {tc.type}",
            type=None,
        )

    # Attempt to parse the tool call arguments as JSON
    arguments: dict[str, Any] = {}
    try:
        arguments = json.loads(tc.function.arguments)
        parse_error = None
    # If the tool call arguments are not valid JSON, return an empty dict with the error
    except Exception as e:
        arguments = {"__parse_error_raw_args": tc.function.arguments}
        parse_error = f"Couldn't parse tool call arguments as JSON: {e}. Original input: {tc.function.arguments}"

    return ToolCall(
        id=tc.id,
        function=tc.function.name,
        arguments=arguments,
        parse_error=parse_error,
        type=tc.type,
    )


def parse_openai_completion(response: ChatCompletion | None, model: str) -> LLMOutput:
    if response is None:
        return LLMOutput(
            model=model,
            completions=[],
            errors=[NoResponseException()],
        )

    # Extract token usage if available
    if response.usage:
        usage = UsageMetrics(
            input=response.usage.prompt_tokens,
            output=response.usage.completion_tokens,
        )
    else:
        logger.warning("OpenAI response did not include usage metrics")
        usage = UsageMetrics()

    return LLMOutput(
        model=response.model,
        completions=[
            LLMCompletion(
                text=choice.message.content,
                finish_reason=choice.finish_reason,
                tool_calls=(
                    [_parse_openai_tool_call(tc) for tc in tcs]
                    if (tcs := choice.message.tool_calls)
                    else None
                ),
                top_logprobs=(
                    [pos.top_logprobs for pos in choice.logprobs.content]
                    if choice.logprobs and choice.logprobs.content is not None
                    else None
                ),
            )
            for choice in response.choices
        ],
        usage=usage,
    )


async def is_openai_api_key_valid(api_key: str) -> bool:
    """
    Test whether an OpenAI API key is valid or invalid.

    Args:
        api_key: The OpenAI API key to test.

    Returns:
        bool: True if the API key is valid, False otherwise.
    """
    client = AsyncOpenAI(api_key=api_key)

    try:
        # Attempt to make a simple API call with minimal tokens/cost
        await client.chat.completions.create(
            model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}], max_tokens=1
        )
        return True
    except AuthenticationError:
        # API key is invalid
        return False
    except Exception:
        # Any other error means the key might be valid but there's another issue
        # For testing key validity specifically, we'll return False only for auth errors
        return True
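As a rough usage sketch (an editor's addition, not shipped in the wheel), the new provider module could be driven as follows. The ChatMessage construction and the model name are assumptions inferred from how parse_chat_messages reads its fields; the actual classes in docent.data_models.chat may expose role-specific message types instead.

# Hypothetical usage sketch; ChatMessage(role=..., content=...) is an assumed constructor.
import asyncio

from docent._llm_util.providers.openai import (
    get_openai_chat_completion_async,
    get_openai_client_async,
)
from docent.data_models.chat import ChatMessage


async def main() -> None:
    client = get_openai_client_async()  # with no key given, reads OPENAI_API_KEY from the environment
    output = await get_openai_chat_completion_async(
        client=client,
        messages=[ChatMessage(role="user", content="Say hello in one word.")],
        model_name="gpt-4o-mini",  # placeholder model name
        max_new_tokens=16,
        timeout=30.0,
    )
    if output.first:
        print(output.first.text)


asyncio.run(main())

get_openai_chat_completion_streaming_async accepts the same arguments plus a streaming_callback, which is awaited with a finalized copy of the partial output after every streamed chunk.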