docent-python 0.1.20a0__py3-none-any.whl → 0.1.22a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python might be problematic. Click here for more details.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +320 -0
- docent/_llm_util/data_models/simple_svc.py +79 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/model_registry.py +126 -0
- docent/_llm_util/prod_llms.py +454 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/transcript.py +2 -0
- docent/judges/__init__.py +21 -0
- docent/judges/impl.py +232 -0
- docent/judges/types.py +240 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +84 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +95 -0
- docent/judges/util/voting.py +114 -0
- docent/trace.py +19 -1
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/METADATA +7 -1
- docent_python-0.1.22a0.dist-info/RECORD +58 -0
- docent_python-0.1.20a0.dist-info/RECORD +0 -34
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
from typing import Any, Literal, cast
|
|
2
|
+
|
|
3
|
+
import backoff
|
|
4
|
+
|
|
5
|
+
# all errors: https://docs.anthropic.com/en/api/errors
|
|
6
|
+
from anthropic import (
|
|
7
|
+
AsyncAnthropic,
|
|
8
|
+
AuthenticationError,
|
|
9
|
+
BadRequestError,
|
|
10
|
+
NotFoundError,
|
|
11
|
+
PermissionDeniedError,
|
|
12
|
+
RateLimitError,
|
|
13
|
+
UnprocessableEntityError,
|
|
14
|
+
)
|
|
15
|
+
from anthropic._types import NOT_GIVEN
|
|
16
|
+
from anthropic.types import (
|
|
17
|
+
InputJSONDelta,
|
|
18
|
+
Message,
|
|
19
|
+
MessageParam,
|
|
20
|
+
RawContentBlockDeltaEvent,
|
|
21
|
+
RawContentBlockStartEvent,
|
|
22
|
+
RawContentBlockStopEvent,
|
|
23
|
+
RawMessageDeltaEvent,
|
|
24
|
+
RawMessageStartEvent,
|
|
25
|
+
RawMessageStreamEvent,
|
|
26
|
+
SignatureDelta,
|
|
27
|
+
TextBlockParam,
|
|
28
|
+
TextDelta,
|
|
29
|
+
ThinkingDelta,
|
|
30
|
+
ToolChoiceAnyParam,
|
|
31
|
+
ToolChoiceAutoParam,
|
|
32
|
+
ToolChoiceParam,
|
|
33
|
+
ToolParam,
|
|
34
|
+
ToolResultBlockParam,
|
|
35
|
+
ToolUseBlockParam,
|
|
36
|
+
)
|
|
37
|
+
from backoff.types import Details
|
|
38
|
+
|
|
39
|
+
from docent._llm_util.data_models.exceptions import (
|
|
40
|
+
CompletionTooLongException,
|
|
41
|
+
ContextWindowException,
|
|
42
|
+
NoResponseException,
|
|
43
|
+
RateLimitException,
|
|
44
|
+
)
|
|
45
|
+
from docent._llm_util.data_models.llm_output import (
|
|
46
|
+
AsyncSingleLLMOutputStreamingCallback,
|
|
47
|
+
FinishReasonType,
|
|
48
|
+
LLMCompletion,
|
|
49
|
+
LLMCompletionPartial,
|
|
50
|
+
LLMOutput,
|
|
51
|
+
LLMOutputPartial,
|
|
52
|
+
ToolCallPartial,
|
|
53
|
+
UsageMetrics,
|
|
54
|
+
finalize_llm_output_partial,
|
|
55
|
+
)
|
|
56
|
+
from docent._llm_util.providers.common import (
|
|
57
|
+
async_timeout_ctx,
|
|
58
|
+
reasoning_budget,
|
|
59
|
+
)
|
|
60
|
+
from docent._log_util import get_logger
|
|
61
|
+
from docent.data_models.chat import ChatMessage, Content, ToolCall, ToolInfo
|
|
62
|
+
|
|
63
|
+
logger = get_logger(__name__)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _print_backoff_message(e: Details):
    """Log a warning whenever ``backoff`` schedules another attempt.

    Args:
        e: The ``backoff`` details mapping for the pending retry. Its ``wait``
            and ``exception`` entries are read.
    """
    wait_seconds = e["wait"]
    error_name = e["exception"].__class__.__name__  # type: ignore
    logger.warning(f"Anthropic backing off for {wait_seconds:.2f}s due to {error_name}")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _is_retryable_error(e: BaseException) -> bool:
    """Tell ``backoff`` whether a failed Anthropic call is worth retrying.

    Args:
        e: The exception raised by the wrapped call.

    Returns:
        False when ``e`` is an instance of one of the types listed below,
        True for every other exception.
    """
    # NOTE(review): CompletionTooLongException (raised by the non-streaming
    # completion function) is not listed here, so it is retried -- confirm
    # that this is intended.
    non_retryable = (
        BadRequestError,
        ContextWindowException,
        AuthenticationError,
        NotImplementedError,
        PermissionDeniedError,
        NotFoundError,
        UnprocessableEntityError,
    )
    return not isinstance(e, non_retryable)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _parse_message_content(content: str | list[Content]) -> str | list[TextBlockParam]:
    """Convert message content into the shape passed to the Anthropic SDK.

    Args:
        content: Either a plain string or a list of content parts.

    Returns:
        The string unchanged, or a new list with one ``TextBlockParam`` per
        text part.

    Raises:
        ValueError: If a part has a ``type`` other than ``"text"``.
    """
    if isinstance(content, str):
        return content

    blocks: list[TextBlockParam] = []
    for part in content:
        if part.type != "text":
            raise ValueError(f"Unsupported content type: {part.type}")
        blocks.append(TextBlockParam(text=part.text, type="text"))
    return blocks
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def parse_chat_messages(messages: list[ChatMessage]) -> tuple[str | None, list[MessageParam]]:
    """Split chat messages into an Anthropic system prompt and message list.

    Args:
        messages: The conversation, in order.

    Returns:
        A ``(system_prompt, anthropic_messages)`` tuple. ``system_prompt`` is
        the text of the last system message seen, or ``None`` when there is
        none. System messages never appear in the returned list.

    Raises:
        ValueError: If a message has an unknown role, or contains a content
            part that ``_parse_message_content`` rejects.
    """
    system_prompt: str | None = None
    converted: list[MessageParam] = []

    for message in messages:
        if message.role == "system":
            # A later system message replaces an earlier one.
            system_prompt = message.text
        elif message.role == "user":
            user_content = _parse_message_content(message.content)
            converted.append(MessageParam(role=message.role, content=user_content))
        elif message.role == "tool":
            # Tool results are sent as a user turn holding one tool_result block.
            tool_result = ToolResultBlockParam(
                tool_use_id=str(message.tool_call_id),
                type="tool_result",
                content=_parse_message_content(message.content),
                is_error=message.error is not None,
            )
            converted.append(MessageParam(role="user", content=[tool_result]))
        elif message.role == "assistant":
            parsed = _parse_message_content(message.content)

            blocks: list[TextBlockParam | ToolUseBlockParam]
            if isinstance(parsed, str):
                # Plain-string content: strip it and skip the text block
                # entirely when nothing is left.
                text = parsed.strip()
                blocks = cast(
                    list[TextBlockParam | ToolUseBlockParam],
                    [TextBlockParam(text=text, type="text")] if text else [],
                )
            else:
                blocks = cast(list[TextBlockParam | ToolUseBlockParam], parsed)

            # Tool calls follow the text, in their original order.
            blocks.extend(
                ToolUseBlockParam(
                    id=call.id,
                    input=call.arguments,
                    name=call.function,
                    type="tool_use",
                )
                for call in message.tool_calls or []
            )
            converted.append(MessageParam(role="assistant", content=blocks))
        else:
            raise ValueError(f"Unknown message role: {message.role}")

    return system_prompt, converted
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def parse_tools(tools: list[ToolInfo]) -> list[ToolParam]:
    """Convert tool definitions into Anthropic ``ToolParam`` entries.

    Args:
        tools: The tool definitions to expose to the model.

    Returns:
        One ``ToolParam`` per tool, in the same order. The input schema is the
        tool's ``parameters`` dumped with ``None`` values excluded.
    """
    tool_params: list[ToolParam] = []
    for tool in tools:
        tool_params.append(
            ToolParam(
                name=tool.name,
                description=tool.description,
                input_schema=tool.parameters.model_dump(exclude_none=True),
            )
        )
    return tool_params
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _parse_tool_choice(tool_choice: Literal["auto", "required"] | None) -> ToolChoiceParam | None:
    """Translate the provider-agnostic tool choice into Anthropic's parameter.

    Args:
        tool_choice: ``"auto"``, ``"required"`` or ``None``.

    Returns:
        An ``auto`` tool-choice param for ``"auto"``, an ``any`` tool-choice
        param for ``"required"``, and ``None`` otherwise.
    """
    if tool_choice == "auto":
        return ToolChoiceAutoParam(type="auto")
    if tool_choice == "required":
        return ToolChoiceAnyParam(type="any")
    # ``None`` (or any unrecognised value) means no tool-choice parameter.
    return None
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _convert_anthropic_error(e: Exception):
    """Map an Anthropic SDK error onto one of docent's own exceptions.

    Args:
        e: The exception raised by the SDK.

    Returns:
        ``ContextWindowException`` for a ``BadRequestError`` whose message
        contains "context limit" (case-insensitive), ``RateLimitException``
        wrapping ``e`` for a ``RateLimitError``, and ``None`` when there is no
        mapping.
    """
    # NOTE(review): only messages containing "context limit" are recognised;
    # confirm that this covers every over-length error the API can return.
    if isinstance(e, BadRequestError) and "context limit" in e.message.lower():
        return ContextWindowException()
    return RateLimitException(e) if isinstance(e, RateLimitError) else None
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@backoff.on_exception(
    backoff.expo,
    exception=(Exception),
    giveup=lambda e: not _is_retryable_error(e),
    max_tries=5,
    factor=3.0,
    on_backoff=_print_backoff_message,
)
async def get_anthropic_chat_completion_streaming_async(
    client: AsyncAnthropic,
    streaming_callback: AsyncSingleLLMOutputStreamingCallback | None,
    messages: list[ChatMessage],
    model_name: str,
    tools: list[ToolInfo] | None = None,
    tool_choice: Literal["auto", "required"] | None = None,
    max_new_tokens: int = 32,
    temperature: float = 1.0,
    reasoning_effort: Literal["low", "medium", "high"] | None = None,
    logprobs: bool = False,
    top_logprobs: int | None = None,
    timeout: float = 5.0,
):
    """Stream one Anthropic chat completion, reporting progress as it arrives.

    The call is retried with exponential backoff (at most 5 tries) unless
    ``_is_retryable_error`` rejects the exception.

    Args:
        client: The Anthropic client used for the request.
        streaming_callback: Optional coroutine awaited after every stream
            event with the finalized output accumulated so far.
        messages: The conversation to send.
        model_name: The Anthropic model identifier.
        tools: Optional tool definitions; omitted from the request when empty.
        tool_choice: Optional tool-choice mode.
        max_new_tokens: Sent as ``max_tokens``.
        temperature: The sampling temperature.
        reasoning_effort: When set, enables thinking with a budget computed by
            ``reasoning_budget``.
        logprobs: Not supported; must be False.
        top_logprobs: Not supported; must be None.
        timeout: Seconds handed to ``async_timeout_ctx``.

    Returns:
        The finalized output, or an ``LLMOutput`` carrying a
        ``NoResponseException`` when the stream produced no events.

    Raises:
        NotImplementedError: If ``logprobs`` or ``top_logprobs`` is requested.
        ContextWindowException, RateLimitException: When
            ``_convert_anthropic_error`` maps the SDK error; other
            ``RateLimitError``/``BadRequestError`` instances are re-raised.
    """
    if logprobs or top_logprobs is not None:
        raise NotImplementedError(
            "We have not implemented logprobs or top_logprobs for Anthropic yet."
        )

    system, input_messages = parse_chat_messages(messages)
    input_tools = parse_tools(tools) if tools else NOT_GIVEN

    try:
        # NOTE(review): the reviewed text had lost its indentation; the stream
        # is consumed inside the timeout context here -- confirm against the
        # original nesting.
        async with async_timeout_ctx(timeout):
            thinking = (
                {
                    "type": "enabled",
                    "budget_tokens": reasoning_budget(max_new_tokens, reasoning_effort),
                }
                if reasoning_effort
                else NOT_GIVEN
            )
            anthropic_tool_choice = _parse_tool_choice(tool_choice) or NOT_GIVEN

            stream = await client.messages.create(
                model=model_name,
                messages=input_messages,
                thinking=thinking,
                tools=input_tools,
                tool_choice=anthropic_tool_choice,
                max_tokens=max_new_tokens,
                temperature=temperature,
                system=system if system is not None else NOT_GIVEN,
                stream=True,
            )

            partial = None
            async for event in stream:
                partial = update_llm_output(partial, event)
                if streaming_callback:
                    await streaming_callback(finalize_llm_output_partial(partial))

            if not partial:
                # Streaming did not produce anything
                return LLMOutput(model=model_name, completions=[], errors=[NoResponseException()])

            # Fully parse the partial output
            return finalize_llm_output_partial(partial)
    except (RateLimitError, BadRequestError) as e:
        converted = _convert_anthropic_error(e)
        if not converted:
            raise
        raise converted from e
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
# Anthropic ``stop_reason`` -> docent finish reason. Lookups for stop reasons
# that are not listed here find no entry.
FINISH_REASON_MAP: dict[str, FinishReasonType] = {
    "end_turn": "stop",
    "max_tokens": "length",
    "stop_sequence": "stop",
    "tool_use": "tool_calls",
    "refusal": "refusal",
}
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def update_llm_output(
    llm_output_partial: LLMOutputPartial | None,
    chunk: RawMessageStreamEvent,
):
    """
    Fold one Anthropic stream event into the accumulated partial output.

    Note that Anthropic only allows one message to be streamed at a time.
    Thus there can only be one completion.

    Args:
        llm_output_partial: The output accumulated from earlier events, or
            None for the first event.
        chunk: The stream event to apply.

    Returns:
        A new ``LLMOutputPartial`` holding a single completion.

    Raises:
        ValueError: If a content-block delta has an unsupported type.
        AssertionError: If no model name is known after applying ``chunk``.
    """
    usage: UsageMetrics = llm_output_partial.usage if llm_output_partial else UsageMetrics()

    # Start from whatever earlier events produced.
    text: str | None = None
    reasoning: str | None = None
    finish_reason: FinishReasonType | None = None
    tool_calls: list[ToolCallPartial | None] | None = None
    model = None
    if llm_output_partial is not None:
        previous = llm_output_partial.completions[0]
        text = previous.text
        reasoning = previous.reasoning_tokens
        finish_reason = previous.finish_reason
        tool_calls = previous.tool_calls  # type: ignore[assignment]
        model = llm_output_partial.model

    if isinstance(chunk, RawMessageStartEvent):
        model = chunk.message.model
    elif isinstance(chunk, RawContentBlockStartEvent):
        block = chunk.content_block
        if block.type == "tool_use":
            # Tool calls are stored at the content-block index, so pad the
            # list with None up to and including that slot.
            slot = chunk.index
            tool_calls = tool_calls or []
            missing = slot - len(tool_calls) + 1
            if missing > 0:
                tool_calls.extend([None] * missing)

            # Arguments arrive later through InputJSONDelta events.
            tool_calls[slot] = ToolCallPartial(
                id=block.id,
                function=block.name,
                arguments_raw="",
                type="function",
            )
    elif isinstance(chunk, RawContentBlockDeltaEvent):
        delta = chunk.delta
        if isinstance(delta, TextDelta):
            text = (text or "") + delta.text
        elif isinstance(delta, ThinkingDelta):
            reasoning = (reasoning or "") + delta.thinking
        elif isinstance(delta, InputJSONDelta):
            slot = chunk.index
            existing = (
                tool_calls[slot] if tool_calls is not None and slot < len(tool_calls) else None
            )
            if existing is None:
                # This should not happen with a well-behaved API, log and skip
                logger.warning(
                    f"Received InputJSONDelta before start event at index {slot}, skipping"
                )
            else:
                # Replace the slot with a copy whose raw arguments are extended.
                tool_calls[slot] = ToolCallPartial(  # type: ignore[index]
                    id=existing.id,
                    function=existing.function,
                    arguments_raw=(existing.arguments_raw or "") + delta.partial_json,
                    type="function",
                )
        elif isinstance(delta, SignatureDelta):
            logger.debug(
                "Anthropic streamed thinking signature block; we should support this soon."
            )
        else:
            raise ValueError(f"Unsupported delta type: {type(chunk.delta)}")
    elif isinstance(chunk, RawContentBlockStopEvent):
        # Nothing to do: the block is considered assembled once it stops.
        pass
    elif isinstance(chunk, RawMessageDeltaEvent):
        stop_reason = chunk.delta.stop_reason
        if stop_reason:
            finish_reason = FINISH_REASON_MAP.get(stop_reason)
        # These token counts are cumulative, so they replace the earlier usage.
        # NOTE(review): assumed to run for every message delta, not only when a
        # stop reason is present -- the reviewed text had lost its indentation.
        usage = UsageMetrics(
            input=chunk.usage.input_tokens,
            output=chunk.usage.output_tokens,
            cache_read=chunk.usage.cache_read_input_tokens,
            cache_write=chunk.usage.cache_creation_input_tokens,
        )

    completion = LLMCompletionPartial(
        text=text,
        tool_calls=tool_calls,
        reasoning_tokens=reasoning,
        finish_reason=finish_reason,
    )

    assert model is not None, "First chunk should always set the cur_model"
    return LLMOutputPartial(
        completions=[completion],  # type: ignore[arg-type]
        model=model,
        usage=usage,
    )
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
@backoff.on_exception(
    backoff.expo,
    exception=(Exception),
    giveup=lambda e: not _is_retryable_error(e),
    max_tries=5,
    factor=3.0,
    on_backoff=_print_backoff_message,
)
async def get_anthropic_chat_completion_async(
    client: AsyncAnthropic,
    messages: list[ChatMessage],
    model_name: str,
    tools: list[ToolInfo] | None = None,
    tool_choice: Literal["auto", "required"] | None = None,
    max_new_tokens: int = 32,
    temperature: float = 1.0,
    reasoning_effort: Literal["low", "medium", "high"] | None = None,
    logprobs: bool = False,
    top_logprobs: int | None = None,
    timeout: float = 5.0,
) -> LLMOutput:
    """Request one (non-streaming) Anthropic chat completion.

    ``logprobs`` and ``top_logprobs`` exist only to keep the signature in line
    with the OpenAI provider; they are not implemented for Anthropic.

    The call is retried with exponential backoff (at most 5 tries) unless
    ``_is_retryable_error`` rejects the exception.

    Args:
        client: The Anthropic client used for the request.
        messages: The conversation to send.
        model_name: The Anthropic model identifier.
        tools: Optional tool definitions; omitted from the request when empty.
        tool_choice: Optional tool-choice mode.
        max_new_tokens: Sent as ``max_tokens``.
        temperature: The sampling temperature.
        reasoning_effort: When set, enables thinking with a budget computed by
            ``reasoning_budget``.
        logprobs: Not supported; must be False.
        top_logprobs: Not supported; must be None.
        timeout: Seconds handed to ``async_timeout_ctx``.

    Returns:
        The parsed completion.

    Raises:
        NotImplementedError: If ``logprobs`` or ``top_logprobs`` is requested.
        CompletionTooLongException: If the first completion was cut off for
            length and contains no text.
        ContextWindowException, RateLimitException: When
            ``_convert_anthropic_error`` maps the SDK error; other
            ``RateLimitError``/``BadRequestError`` instances are re-raised.
    """
    if logprobs or top_logprobs is not None:
        raise NotImplementedError(
            "We have not implemented logprobs or top_logprobs for Anthropic yet."
        )

    system, input_messages = parse_chat_messages(messages)
    input_tools = parse_tools(tools) if tools else NOT_GIVEN

    try:
        # NOTE(review): the reviewed text had lost its indentation; parsing is
        # done inside the timeout context here -- confirm against the original.
        async with async_timeout_ctx(timeout):
            thinking = (
                {
                    "type": "enabled",
                    "budget_tokens": reasoning_budget(max_new_tokens, reasoning_effort),
                }
                if reasoning_effort
                else NOT_GIVEN
            )
            anthropic_tool_choice = _parse_tool_choice(tool_choice) or NOT_GIVEN

            raw_output = await client.messages.create(
                model=model_name,
                messages=input_messages,
                thinking=thinking,
                tools=input_tools,
                tool_choice=anthropic_tool_choice,
                max_tokens=max_new_tokens,
                temperature=temperature,
                system=system if system is not None else NOT_GIVEN,
            )

            output = parse_anthropic_completion(raw_output, model_name)

            # An empty completion that hit the token limit is an error, not a result.
            first = output.first
            if first and first.finish_reason == "length" and first.no_text:
                raise CompletionTooLongException(
                    "Completion empty due to truncation. Consider increasing max_new_tokens."
                )

            return output
    except (RateLimitError, BadRequestError) as e:
        converted = _convert_anthropic_error(e)
        if not converted:
            raise
        raise converted from e
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def get_anthropic_client_async(api_key: str | None = None) -> AsyncAnthropic:
    """Create an async Anthropic client.

    Args:
        api_key: The key to use. When it is ``None`` or empty the client is
            constructed without an explicit key.

    Returns:
        A new ``AsyncAnthropic`` instance.
    """
    if not api_key:
        return AsyncAnthropic()
    return AsyncAnthropic(api_key=api_key)
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def parse_anthropic_completion(message: Message | None, model: str) -> LLMOutput:
    """Convert a complete Anthropic ``Message`` into an ``LLMOutput``.

    Args:
        message: The message returned by the API, or None if there was none.
        model: The model name to record on the output.

    Returns:
        An ``LLMOutput`` with exactly one completion, or with no completions
        and a ``NoResponseException`` when ``message`` is None.

    Raises:
        ValueError: If the message contains more than one text block, or a
            block type that is not handled below.
    """
    if message is None:
        return LLMOutput(
            model=model,
            completions=[],
            errors=[NoResponseException()],
        )

    # Use the same table as the streaming path. A missing or unrecognised stop
    # reason is reported as "error".
    finish_reason = FINISH_REASON_MAP.get(message.stop_reason or "", "error")

    text = None
    tool_calls: list[ToolCall] = []
    reasoning_tokens = None
    for block in message.content:
        if block.type == "text":
            if text is not None:
                raise ValueError(
                    "Anthropic API returned multiple text blocks; this was unexpected."
                )
            text = block.text
        elif block.type == "tool_use":
            tool_calls.append(
                ToolCall(
                    id=block.id,
                    function=block.name,
                    arguments=cast(dict[str, Any], block.input),
                    type="function",
                )
            )
        elif block.type == "thinking":
            reasoning_tokens = block.thinking
        elif block.type == "redacted_thinking":
            # Redacted reasoning has no readable text. Skip it instead of
            # failing the whole completion: the ValueError below would also be
            # retried by the backoff wrapper, and the streaming path already
            # ignores these blocks.
            logger.debug("Anthropic returned a redacted_thinking block; ignoring it.")
        else:
            raise ValueError(f"Unknown block type: {block.type}")

    usage = UsageMetrics(
        input=message.usage.input_tokens,
        output=message.usage.output_tokens,
        cache_read=message.usage.cache_read_input_tokens,
        cache_write=message.usage.cache_creation_input_tokens,
    )

    return LLMOutput(
        model=model,
        completions=[
            LLMCompletion(
                text=text,
                tool_calls=tool_calls,
                reasoning_tokens=reasoning_tokens,
                finish_reason=finish_reason,  # type: ignore
            )
        ],
        usage=usage,
    )
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
async def is_anthropic_api_key_valid(api_key: str) -> bool:
    """
    Test whether an Anthropic API key is valid or invalid.

    Args:
        api_key: The Anthropic API key to test.

    Returns:
        bool: False if the key is blank or the API rejects it with an
        authentication error, True otherwise (including when the check fails
        for an unrelated reason).
    """
    # A blank key can never authenticate. Decide that here so the answer does
    # not depend on how the SDK reacts to it: every non-authentication failure
    # below is reported as "valid".
    if not api_key or not api_key.strip():
        return False

    # The context manager closes the client's HTTP resources once the probe
    # is done.
    async with AsyncAnthropic(api_key=api_key) as client:
        try:
            # Attempt to make a simple API call with minimal tokens/cost
            await client.messages.create(
                model="claude-3-haiku-20240307",
                max_tokens=1,
                messages=[{"role": "user", "content": "hi"}],
            )
            return True
        except AuthenticationError:
            # API key is invalid
            return False
        except Exception:
            # Deliberately best-effort: any other error means the key might be
            # valid but something else went wrong, so only authentication
            # errors count as "invalid".
            return True
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from typing import Any, AsyncIterator, Literal, cast
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@asynccontextmanager
|
|
8
|
+
async def async_timeout_ctx(timeout: float | None) -> AsyncIterator[None]:
|
|
9
|
+
if timeout:
|
|
10
|
+
async with asyncio.timeout(timeout):
|
|
11
|
+
yield
|
|
12
|
+
else:
|
|
13
|
+
# No-op async contextmanager
|
|
14
|
+
yield
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def reasoning_budget(max_new_tokens: int, effort: Literal["low", "medium", "high"]) -> int:
    """Return the share of ``max_new_tokens`` reserved for reasoning.

    Args:
        max_new_tokens: The total token budget of the completion.
        effort: ``"high"`` reserves 75%, ``"medium"`` 50%, and any other
            value 25%.

    Returns:
        The budget, truncated to an int.
    """
    share = 0.25
    if effort == "medium":
        share = 0.5
    elif effort == "high":
        share = 0.75
    return int(max_new_tokens * share)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def coerce_tool_args(args: Any) -> dict[str, Any]:
    """Coerce tool-call arguments of any shape into a dict.

    Args:
        args: A dict, a JSON string, or anything else.

    Returns:
        ``args`` itself when it is a dict, or the decoded object when ``args``
        is a JSON string that decodes to a dict. In every other case the
        result is ``{"__parse_error_raw_args": ...}``, holding the original
        string or, for non-strings, ``str(args)``.
    """
    raw_key = "__parse_error_raw_args"

    if isinstance(args, dict):
        return cast(dict[str, Any], args)

    if not isinstance(args, str):
        # Fallback: unknown structure
        return {raw_key: str(args)}

    try:
        decoded = json.loads(args)
    except Exception:
        # Never raise from here: keep the undecodable text for inspection.
        return {raw_key: args}

    if isinstance(decoded, dict):
        return cast(dict[str, Any], decoded)
    # Valid JSON, but not an object.
    return {raw_key: args}
|