docent-python 0.1.20a0__py3-none-any.whl → 0.1.21a0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

This version of docent-python has been flagged as potentially problematic.

@@ -0,0 +1,537 @@
+ from typing import Any, Literal, cast
+
+ import backoff
+
+ # all errors: https://docs.anthropic.com/en/api/errors
+ from anthropic import (
+     AsyncAnthropic,
+     AuthenticationError,
+     BadRequestError,
+     NotFoundError,
+     PermissionDeniedError,
+     RateLimitError,
+     UnprocessableEntityError,
+ )
+ from anthropic._types import NOT_GIVEN
+ from anthropic.types import (
+     InputJSONDelta,
+     Message,
+     MessageParam,
+     RawContentBlockDeltaEvent,
+     RawContentBlockStartEvent,
+     RawContentBlockStopEvent,
+     RawMessageDeltaEvent,
+     RawMessageStartEvent,
+     RawMessageStreamEvent,
+     SignatureDelta,
+     TextBlockParam,
+     TextDelta,
+     ThinkingDelta,
+     ToolChoiceAnyParam,
+     ToolChoiceAutoParam,
+     ToolChoiceParam,
+     ToolParam,
+     ToolResultBlockParam,
+     ToolUseBlockParam,
+ )
+ from backoff.types import Details
+
+ from docent._llm_util.data_models.exceptions import (
+     CompletionTooLongException,
+     ContextWindowException,
+     NoResponseException,
+     RateLimitException,
+ )
+ from docent._llm_util.data_models.llm_output import (
+     AsyncSingleLLMOutputStreamingCallback,
+     FinishReasonType,
+     LLMCompletion,
+     LLMCompletionPartial,
+     LLMOutput,
+     LLMOutputPartial,
+     ToolCallPartial,
+     UsageMetrics,
+     finalize_llm_output_partial,
+ )
+ from docent._llm_util.providers.common import (
+     async_timeout_ctx,
+     reasoning_budget,
+ )
+ from docent._log_util import get_logger
+ from docent.data_models.chat import ChatMessage, Content, ToolCall, ToolInfo
+
+ logger = get_logger(__name__)
+
+
+ def _print_backoff_message(e: Details):
+     logger.warning(
+         f"Anthropic backing off for {e['wait']:.2f}s due to {e['exception'].__class__.__name__}"  # type: ignore
+     )
+
+ def _is_retryable_error(e: BaseException) -> bool:
+     # Client-side errors that will not succeed on retry
+     non_retryable = (
+         BadRequestError,
+         ContextWindowException,
+         AuthenticationError,
+         NotImplementedError,
+         PermissionDeniedError,
+         NotFoundError,
+         UnprocessableEntityError,
+     )
+     return not isinstance(e, non_retryable)
+
+
+ def _parse_message_content(content: str | list[Content]) -> str | list[TextBlockParam]:
+     if isinstance(content, str):
+         return content
+     else:
+         result: list[TextBlockParam] = []
+         for sub_content in content:
+             if sub_content.type == "text":
+                 result.append(TextBlockParam(text=sub_content.text, type="text"))
+             else:
+                 raise ValueError(f"Unsupported content type: {sub_content.type}")
+         return result
+
+
+ def parse_chat_messages(messages: list[ChatMessage]) -> tuple[str | None, list[MessageParam]]:
+     result: list[MessageParam] = []
+     system_prompt: str | None = None
+
+     for message in messages:
+         if message.role == "user":
+             result.append(
+                 MessageParam(
+                     role=message.role,
+                     content=_parse_message_content(message.content),
+                 )
+             )
+         elif message.role == "assistant":
+             message_content = _parse_message_content(message.content)
+             # Build content list without creating empty text blocks
+             if isinstance(message_content, str):
+                 stripped = message_content.strip()
+                 all_content = cast(
+                     list[TextBlockParam | ToolUseBlockParam],
+                     ([TextBlockParam(text=stripped, type="text")] if stripped else []),
+                 )
+             else:
+                 all_content = cast(list[TextBlockParam | ToolUseBlockParam], message_content)
+             for tool_call in message.tool_calls or []:
+                 all_content.append(
+                     ToolUseBlockParam(
+                         id=tool_call.id,
+                         input=tool_call.arguments,
+                         name=tool_call.function,
+                         type="tool_use",
+                     )
+                 )
+             result.append(
+                 MessageParam(
+                     role="assistant",
+                     content=all_content,
+                 )
+             )
+         elif message.role == "tool":
+             result.append(
+                 MessageParam(
+                     role="user",
+                     content=[
+                         ToolResultBlockParam(
+                             tool_use_id=str(message.tool_call_id),
+                             type="tool_result",
+                             content=_parse_message_content(message.content),
+                             is_error=message.error is not None,
+                         )
+                     ],
+                 )
+             )
+         elif message.role == "system":
+             system_prompt = message.text
+         else:
+             raise ValueError(f"Unknown message role: {message.role}")
+
+     return system_prompt, result
+
+
+ def parse_tools(tools: list[ToolInfo]) -> list[ToolParam]:
+     return [
+         ToolParam(
+             name=tool.name,
+             description=tool.description,
+             input_schema=tool.parameters.model_dump(exclude_none=True),
+         )
+         for tool in tools
+     ]
+
+
+ def _parse_tool_choice(tool_choice: Literal["auto", "required"] | None) -> ToolChoiceParam | None:
+     if tool_choice is None:
+         return None
+     elif tool_choice == "auto":
+         return ToolChoiceAutoParam(type="auto")
+     elif tool_choice == "required":
+         return ToolChoiceAnyParam(type="any")
+
+
+ def _convert_anthropic_error(e: Exception) -> Exception | None:
+     if isinstance(e, BadRequestError):
+         if "context limit" in e.message.lower():
+             return ContextWindowException()
+     if isinstance(e, RateLimitError):
+         return RateLimitException(e)
+     return None
+
+
+ @backoff.on_exception(
+     backoff.expo,
+     exception=Exception,
+     giveup=lambda e: not _is_retryable_error(e),
+     max_tries=5,
+     factor=3.0,
+     on_backoff=_print_backoff_message,
+ )
+ async def get_anthropic_chat_completion_streaming_async(
+     client: AsyncAnthropic,
+     streaming_callback: AsyncSingleLLMOutputStreamingCallback | None,
+     messages: list[ChatMessage],
+     model_name: str,
+     tools: list[ToolInfo] | None = None,
+     tool_choice: Literal["auto", "required"] | None = None,
+     max_new_tokens: int = 32,
+     temperature: float = 1.0,
+     reasoning_effort: Literal["low", "medium", "high"] | None = None,
+     logprobs: bool = False,
+     top_logprobs: int | None = None,
+     timeout: float = 5.0,
+ ) -> LLMOutput:
+     if logprobs or top_logprobs is not None:
+         raise NotImplementedError(
+             "We have not implemented logprobs or top_logprobs for Anthropic yet."
+         )
+
+     system, input_messages = parse_chat_messages(messages)
+     input_tools = parse_tools(tools) if tools else NOT_GIVEN
+
+     try:
+         async with async_timeout_ctx(timeout):
+             stream = await client.messages.create(
+                 model=model_name,
+                 messages=input_messages,
+                 thinking=(
+                     {
+                         "type": "enabled",
+                         "budget_tokens": reasoning_budget(max_new_tokens, reasoning_effort),
+                     }
+                     if reasoning_effort
+                     else NOT_GIVEN
+                 ),
+                 tools=input_tools,
+                 tool_choice=_parse_tool_choice(tool_choice) or NOT_GIVEN,
+                 max_tokens=max_new_tokens,
+                 temperature=temperature,
+                 system=system if system is not None else NOT_GIVEN,
+                 stream=True,
+             )
+
+             llm_output_partial = None
+             async for chunk in stream:
+                 llm_output_partial = update_llm_output(llm_output_partial, chunk)
+                 if streaming_callback:
+                     await streaming_callback(finalize_llm_output_partial(llm_output_partial))
+
+         # Fully parse the partial output
+         if llm_output_partial:
+             return finalize_llm_output_partial(llm_output_partial)
+         else:
+             # Streaming did not produce anything
+             return LLMOutput(model=model_name, completions=[], errors=[NoResponseException()])
+     except (RateLimitError, BadRequestError) as e:
+         if e2 := _convert_anthropic_error(e):
+             raise e2 from e
+         else:
+             raise
+
+
+ FINISH_REASON_MAP: dict[str, FinishReasonType] = {
+     "end_turn": "stop",
+     "max_tokens": "length",
+     "stop_sequence": "stop",
+     "tool_use": "tool_calls",
+     "refusal": "refusal",
+ }
+
+
+ def update_llm_output(
+     llm_output_partial: LLMOutputPartial | None,
+     chunk: RawMessageStreamEvent,
+ ) -> LLMOutputPartial:
+     """
+     Note that Anthropic only allows one message to be streamed at a time.
+     Thus there can only be one completion.
+     """
+
+     usage: UsageMetrics = llm_output_partial.usage if llm_output_partial else UsageMetrics()
+
+     if llm_output_partial is not None:
+         cur_text: str | None = llm_output_partial.completions[0].text
+         cur_reasoning_tokens: str | None = llm_output_partial.completions[0].reasoning_tokens
+         cur_finish_reason: FinishReasonType | None = llm_output_partial.completions[0].finish_reason
+         cur_tool_calls: list[ToolCallPartial | None] | None = llm_output_partial.completions[0].tool_calls  # type: ignore[assignment]
+         cur_model = llm_output_partial.model
+     else:
+         cur_text, cur_reasoning_tokens, cur_finish_reason, cur_model = None, None, None, None
+         cur_tool_calls = None
+
+     if isinstance(chunk, RawMessageStartEvent):
+         cur_model = chunk.message.model
+     elif isinstance(chunk, RawContentBlockStartEvent):
+         # If a tool_use block starts, initialize a ToolCallPartial slot using the block index
+         content_block = chunk.content_block
+         if content_block.type == "tool_use":
+             # Ensure the tool_calls array exists and is long enough
+             index = chunk.index
+             cur_tool_calls = cur_tool_calls or []
+             if index >= len(cur_tool_calls):
+                 cur_tool_calls.extend([None] * (index - len(cur_tool_calls) + 1))
+
+             # Initialize the partial with id/name; arguments will stream via InputJSONDelta
+             cur_tool_calls[index] = ToolCallPartial(
+                 id=content_block.id,
+                 function=content_block.name,
+                 arguments_raw="",
+                 type="function",
+             )
+     elif isinstance(chunk, RawContentBlockDeltaEvent):
+         if isinstance(chunk.delta, TextDelta):
+             cur_text = (cur_text or "") + chunk.delta.text
+         elif isinstance(chunk.delta, ThinkingDelta):
+             cur_reasoning_tokens = (cur_reasoning_tokens or "") + chunk.delta.thinking
+         elif isinstance(chunk.delta, InputJSONDelta):
+             # Append streamed JSON into the corresponding ToolCallPartial
+             index = chunk.index
+             if (
+                 cur_tool_calls is None
+                 or index >= len(cur_tool_calls)
+                 or cur_tool_calls[index] is None
+             ):
+                 # This should not happen with a well-behaved API; log and skip
+                 logger.warning(
+                     f"Received InputJSONDelta before start event at index {index}, skipping"
+                 )
+             else:
+                 cur_tool_calls[index] = ToolCallPartial(
+                     id=cur_tool_calls[index].id,  # type: ignore[union-attr]
+                     function=cur_tool_calls[index].function,  # type: ignore[union-attr]
+                     arguments_raw=(cur_tool_calls[index].arguments_raw or "") + chunk.delta.partial_json,  # type: ignore[union-attr]
+                     type="function",
+                 )
+         elif isinstance(chunk.delta, SignatureDelta):
+             logger.debug(
+                 "Anthropic streamed a thinking signature block; we should support this soon."
+             )
+         else:
+             raise ValueError(f"Unsupported delta type: {type(chunk.delta)}")
+     elif isinstance(chunk, RawContentBlockStopEvent):
+         # Nothing to do on stop; tool call is considered assembled once stop occurs
+         pass
+     elif isinstance(chunk, RawMessageDeltaEvent):
+         if stop_reason := chunk.delta.stop_reason:
+             cur_finish_reason = FINISH_REASON_MAP.get(stop_reason)
+         # These token counts are cumulative
+         usage = UsageMetrics(
+             input=chunk.usage.input_tokens,
+             output=chunk.usage.output_tokens,
+             cache_read=chunk.usage.cache_read_input_tokens,
+             cache_write=chunk.usage.cache_creation_input_tokens,
+         )
+
+     completions: list[LLMCompletionPartial] = []
+     completions.append(
+         LLMCompletionPartial(
+             text=cur_text,
+             tool_calls=cur_tool_calls,
+             reasoning_tokens=cur_reasoning_tokens,
+             finish_reason=cur_finish_reason,
+         )
+     )
+
+     assert cur_model is not None, "First chunk should always set the cur_model"
+     return LLMOutputPartial(
+         completions=completions,  # type: ignore[arg-type]
+         model=cur_model,
+         usage=usage,
+     )
+
+
+ @backoff.on_exception(
+     backoff.expo,
+     exception=Exception,
+     giveup=lambda e: not _is_retryable_error(e),
+     max_tries=5,
+     factor=3.0,
+     on_backoff=_print_backoff_message,
+ )
+ async def get_anthropic_chat_completion_async(
+     client: AsyncAnthropic,
+     messages: list[ChatMessage],
+     model_name: str,
+     tools: list[ToolInfo] | None = None,
+     tool_choice: Literal["auto", "required"] | None = None,
+     max_new_tokens: int = 32,
+     temperature: float = 1.0,
+     reasoning_effort: Literal["low", "medium", "high"] | None = None,
+     logprobs: bool = False,
+     top_logprobs: int | None = None,
+     timeout: float = 5.0,
+ ) -> LLMOutput:
+     """
+     Note from kevin 1/29/2025:
+     logprobs and top_logprobs were recently added to the OpenAI endpoint,
+     which broke some of my code. I'm adding the same parameters here to keep
+     the Anthropic interface compatible, but they are not implemented:
+     passing either one raises NotImplementedError.
+     """
398
+
399
+ if logprobs or top_logprobs is not None:
400
+ raise NotImplementedError(
401
+ "We have not implemented logprobs or top_logprobs for Anthropic yet."
402
+ )
403
+
404
+ system, input_messages = parse_chat_messages(messages)
405
+ input_tools = parse_tools(tools) if tools else NOT_GIVEN
406
+
407
+ try:
408
+ async with async_timeout_ctx(timeout):
409
+ raw_output = await client.messages.create(
410
+ model=model_name,
411
+ messages=input_messages,
412
+ thinking=(
413
+ {
414
+ "type": "enabled",
415
+ "budget_tokens": reasoning_budget(max_new_tokens, reasoning_effort),
416
+ }
417
+ if reasoning_effort
418
+ else NOT_GIVEN
419
+ ),
420
+ tools=input_tools,
421
+ tool_choice=_parse_tool_choice(tool_choice) or NOT_GIVEN,
422
+ max_tokens=max_new_tokens,
423
+ temperature=temperature,
424
+ system=system if system is not None else NOT_GIVEN,
425
+ )
426
+
427
+ output = parse_anthropic_completion(raw_output, model_name)
428
+ if output.first and output.first.finish_reason == "length" and output.first.no_text:
429
+ raise CompletionTooLongException(
430
+ "Completion empty due to truncation. Consider increasing max_new_tokens."
431
+ )
432
+
433
+ return output
434
+ except (RateLimitError, BadRequestError) as e:
435
+ if e2 := _convert_anthropic_error(e):
436
+ raise e2 from e
437
+ else:
438
+ raise
439
+
440
+
441
+ def get_anthropic_client_async(api_key: str | None = None) -> AsyncAnthropic:
442
+ return AsyncAnthropic(api_key=api_key) if api_key else AsyncAnthropic()
443
+
+
+ def parse_anthropic_completion(message: Message | None, model: str) -> LLMOutput:
+     if message is None:
+         return LLMOutput(
+             model=model,
+             completions=[],
+             errors=[NoResponseException()],
+         )
+
+     # Map Anthropic stop reasons onto our finish reasons; unknown values become "error"
+     finish_reason: FinishReasonType = FINISH_REASON_MAP.get(message.stop_reason or "", "error")
+
+     text = None
+     tool_calls: list[ToolCall] = []
+     reasoning_tokens = None
+     for block in message.content:
+         if block.type == "text":
+             if text is not None:
+                 raise ValueError(
+                     "Anthropic API returned multiple text blocks; this was unexpected."
+                 )
+             text = block.text
+         elif block.type == "tool_use":
+             tool_calls.append(
+                 ToolCall(
+                     id=block.id,
+                     function=block.name,
+                     arguments=cast(dict[str, Any], block.input),
+                     type="function",
+                 )
+             )
+         elif block.type == "thinking":
+             reasoning_tokens = block.thinking
+         else:
+             raise ValueError(f"Unknown block type: {block.type}")
+
+     usage = UsageMetrics(
+         input=message.usage.input_tokens,
+         output=message.usage.output_tokens,
+         cache_read=message.usage.cache_read_input_tokens,
+         cache_write=message.usage.cache_creation_input_tokens,
+     )
+
+     return LLMOutput(
+         model=model,
+         completions=[
+             LLMCompletion(
+                 text=text,
+                 tool_calls=tool_calls,
+                 reasoning_tokens=reasoning_tokens,
+                 finish_reason=finish_reason,
+             )
+         ],
+         usage=usage,
+     )
+
+
+ async def is_anthropic_api_key_valid(api_key: str) -> bool:
+     """
+     Test whether an Anthropic API key is valid or invalid.
+
+     Args:
+         api_key: The Anthropic API key to test.
+
+     Returns:
+         bool: True if the API key is valid, False otherwise.
+     """
+     client = AsyncAnthropic(api_key=api_key)
+
+     try:
+         # Attempt to make a simple API call with minimal tokens/cost
+         await client.messages.create(
+             model="claude-3-haiku-20240307",
+             max_tokens=1,
+             messages=[{"role": "user", "content": "hi"}],
+         )
+         return True
+     except AuthenticationError:
+         # API key is invalid
+         return False
+     except Exception:
+         # Any other error means the key might be valid but there's another issue.
+         # For testing key validity specifically, we return False only for auth errors.
+         return True
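
For reviewers who want to sanity-check the new provider, a minimal usage sketch follows. The module path, the UserMessage constructor, and the model name are illustrative assumptions; they are not confirmed by this diff.

    import asyncio

    from docent._llm_util.providers.anthropic import (  # assumed module path
        get_anthropic_chat_completion_async,
        get_anthropic_client_async,
    )
    from docent.data_models.chat import UserMessage  # assumed ChatMessage variant

    async def main() -> None:
        # Without an explicit key, AsyncAnthropic falls back to ANTHROPIC_API_KEY
        client = get_anthropic_client_async()
        output = await get_anthropic_chat_completion_async(
            client=client,
            messages=[UserMessage(content="Say hello in one word.")],
            model_name="claude-3-haiku-20240307",
            max_new_tokens=64,  # the default of 32 is very small
            timeout=30.0,  # the 5.0s default is tight for a real completion
        )
        if output.first:
            print(output.first.text)

    asyncio.run(main())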
@@ -0,0 +1,41 @@
+ import asyncio
+ import json
+ from contextlib import asynccontextmanager
+ from typing import Any, AsyncIterator, Literal, cast
+
+
+ @asynccontextmanager
+ async def async_timeout_ctx(timeout: float | None) -> AsyncIterator[None]:
+     if timeout:
+         # asyncio.timeout requires Python 3.11+
+         async with asyncio.timeout(timeout):
+             yield
+     else:
+         # No-op async context manager
+         yield
+
+
+ def reasoning_budget(max_new_tokens: int, effort: Literal["low", "medium", "high"]) -> int:
+     # Reserve a fixed fraction of the completion budget for extended thinking
+     if effort == "high":
+         ratio = 0.75
+     elif effort == "medium":
+         ratio = 0.5
+     else:
+         ratio = 0.25
+     return int(max_new_tokens * ratio)
+
+
+ def coerce_tool_args(args: Any) -> dict[str, Any]:
+     if isinstance(args, dict):
+         return cast(dict[str, Any], args)
+     if isinstance(args, str):
+         try:
+             loaded = json.loads(args)
+             return (
+                 cast(dict[str, Any], loaded)
+                 if isinstance(loaded, dict)
+                 else {"__parse_error_raw_args": args}
+             )
+         except Exception:
+             return {"__parse_error_raw_args": args}
+     # Fallback: unknown structure
+     return {"__parse_error_raw_args": str(args)}