docent-python 0.1.19a0__py3-none-any.whl → 0.1.21a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of docent-python might be problematic.

Files changed (34)
  1. docent/_llm_util/__init__.py +0 -0
  2. docent/_llm_util/data_models/__init__.py +0 -0
  3. docent/_llm_util/data_models/exceptions.py +48 -0
  4. docent/_llm_util/data_models/llm_output.py +320 -0
  5. docent/_llm_util/data_models/simple_svc.py +79 -0
  6. docent/_llm_util/llm_cache.py +193 -0
  7. docent/_llm_util/model_registry.py +126 -0
  8. docent/_llm_util/prod_llms.py +454 -0
  9. docent/_llm_util/providers/__init__.py +0 -0
  10. docent/_llm_util/providers/anthropic.py +537 -0
  11. docent/_llm_util/providers/common.py +41 -0
  12. docent/_llm_util/providers/google.py +530 -0
  13. docent/_llm_util/providers/openai.py +745 -0
  14. docent/_llm_util/providers/openrouter.py +375 -0
  15. docent/_llm_util/providers/preference_types.py +104 -0
  16. docent/_llm_util/providers/provider_registry.py +164 -0
  17. docent/data_models/transcript.py +2 -0
  18. docent/data_models/util.py +170 -0
  19. docent/judges/__init__.py +21 -0
  20. docent/judges/impl.py +222 -0
  21. docent/judges/types.py +240 -0
  22. docent/judges/util/forgiving_json.py +108 -0
  23. docent/judges/util/meta_schema.json +84 -0
  24. docent/judges/util/meta_schema.py +29 -0
  25. docent/judges/util/parse_output.py +95 -0
  26. docent/judges/util/voting.py +84 -0
  27. docent/sdk/client.py +5 -2
  28. docent/trace.py +1 -1
  29. docent/trace_2.py +1842 -0
  30. {docent_python-0.1.19a0.dist-info → docent_python-0.1.21a0.dist-info}/METADATA +10 -5
  31. docent_python-0.1.21a0.dist-info/RECORD +58 -0
  32. docent_python-0.1.19a0.dist-info/RECORD +0 -32
  33. {docent_python-0.1.19a0.dist-info → docent_python-0.1.21a0.dist-info}/WHEEL +0 -0
  34. {docent_python-0.1.19a0.dist-info → docent_python-0.1.21a0.dist-info}/licenses/LICENSE.md +0 -0
docent/_llm_util/providers/google.py (new file)
@@ -0,0 +1,530 @@
+ from typing import Any, Literal, cast
+
+ import backoff
+ import requests
+ from backoff.types import Details
+ from google import genai
+ from google.genai import errors, types
+ from google.genai.client import AsyncClient as AsyncGoogle
+
+ from docent._llm_util.data_models.exceptions import (
+     CompletionTooLongException,
+     ContextWindowException,
+     NoResponseException,
+     RateLimitException,
+ )
+ from docent._llm_util.data_models.llm_output import (
+     AsyncSingleLLMOutputStreamingCallback,
+     LLMCompletion,
+     LLMOutput,
+     UsageMetrics,
+ )
+ from docent._llm_util.providers.common import (
+     async_timeout_ctx,
+     coerce_tool_args,
+     reasoning_budget,
+ )
+ from docent._log_util import get_logger
+ from docent.data_models.chat import ChatMessage, Content, ToolCall, ToolInfo
+
+
+ def get_google_client_async(api_key: str | None = None) -> AsyncGoogle:
+     if api_key:
+         return genai.Client(api_key=api_key).aio
+     return genai.Client().aio
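
A minimal usage sketch (assuming that genai.Client() with no arguments resolves credentials from the environment, e.g. a GOOGLE_API_KEY variable, which is google-genai's documented fallback):

    from docent._llm_util.providers.google import get_google_client_async

    client = get_google_client_async()               # ambient credentials
    client = get_google_client_async(api_key="...")  # explicit key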
+
+
+ logger = get_logger(__name__)
+
+
+ def _convert_google_error(e: errors.APIError):
+     if e.code in [429, 502, 503, 504]:
+         return RateLimitException(e)
+     elif e.code == 400 and "maximum number of tokens" in str(e).lower():
+         return ContextWindowException()
+     return None
+
+
+ def _print_backoff_message(e: Details):
+     logger.warning(
+         f"Google backing off for {e['wait']:.2f}s due to {e['exception'].__class__.__name__}"  # type: ignore
+     )
+
+
+ def _is_retryable_error(exception: BaseException) -> bool:
+     """Return True for transient failures: retryable Google API errors (429/5xx) or connection errors."""
+     if isinstance(exception, errors.APIError):
+         return exception.code in [429, 500, 502, 503, 504]
+     if isinstance(exception, requests.exceptions.ConnectionError):
+         return True
+     return False
+
+
+ @backoff.on_exception(
+     backoff.expo,
+     exception=(Exception),
+     giveup=lambda e: not _is_retryable_error(e),
+     max_tries=3,
+     factor=2.0,
+     on_backoff=_print_backoff_message,
+ )
+ async def get_google_chat_completion_async(
+     client: AsyncGoogle,
+     messages: list[ChatMessage],
+     model_name: str,
+     tools: list[ToolInfo] | None = None,
+     tool_choice: Literal["auto", "required"] | None = None,
+     max_new_tokens: int = 32,
+     temperature: float = 1.0,
+     reasoning_effort: Literal["low", "medium", "high"] | None = None,
+     logprobs: bool = False,
+     top_logprobs: int | None = None,
+     timeout: float = 5.0,
+ ) -> LLMOutput:
+     if logprobs or top_logprobs is not None:
+         raise NotImplementedError(
+             "We have not implemented logprobs or top_logprobs for Google yet."
+         )
+
+     system, input_messages = _parse_chat_messages(messages, tools_provided=bool(tools))
+
+     async with async_timeout_ctx(timeout):
+         thinking_cfg = None
+         if reasoning_effort:
+             thinking_cfg = types.ThinkingConfig(
+                 include_thoughts=True,
+                 thinking_budget=reasoning_budget(max_new_tokens, reasoning_effort),
+             )
+
+         raw_output = await client.models.generate_content(  # type: ignore
+             model=model_name,
+             contents=input_messages,  # type: ignore
+             config=types.GenerateContentConfig(
+                 temperature=temperature,
+                 thinking_config=thinking_cfg,
+                 max_output_tokens=max_new_tokens,
+                 system_instruction=system,
+                 tools=_parse_tools(tools) if tools else None,
+                 tool_config=(
+                     types.ToolConfig(function_calling_config=_parse_tool_choice(tool_choice))
+                     if tool_choice is not None
+                     else None
+                 ),
+             ),
+         )
+
+     output = _parse_google_completion(raw_output, model_name)
+     if output.first and output.first.finish_reason == "length" and output.first.no_text:
+         raise CompletionTooLongException(
+             f"Completion empty due to truncation. Consider increasing max_new_tokens (currently {max_new_tokens})."
+         )
+
+     return output
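
A hedged call sketch: the message list is left as a placeholder because docent's ChatMessage shapes are defined elsewhere in the package, the model id is illustrative, and output.first is the first completion as used above.

    import asyncio

    async def demo() -> None:
        client = get_google_client_async()
        messages: list[ChatMessage] = [...]  # build docent ChatMessage objects here
        output = await get_google_chat_completion_async(
            client=client,
            messages=messages,
            model_name="gemini-2.0-flash",  # illustrative model id
            max_new_tokens=256,
            timeout=30.0,
        )
        if output.first:
            print(output.first.text)

    asyncio.run(demo())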
+
+
+ @backoff.on_exception(
+     backoff.expo,
+     exception=(Exception),
+     giveup=lambda e: not _is_retryable_error(e),
+     max_tries=3,
+     factor=2.0,
+     on_backoff=_print_backoff_message,
+ )
+ async def get_google_chat_completion_streaming_async(
+     client: AsyncGoogle,
+     streaming_callback: AsyncSingleLLMOutputStreamingCallback | None,
+     messages: list[ChatMessage],
+     model_name: str,
+     tools: list[ToolInfo] | None = None,
+     tool_choice: Literal["auto", "required"] | None = None,
+     max_new_tokens: int = 32,
+     temperature: float = 1.0,
+     reasoning_effort: Literal["low", "medium", "high"] | None = None,
+     logprobs: bool = False,
+     top_logprobs: int | None = None,
+     timeout: float = 5.0,
+ ) -> LLMOutput:
+     if logprobs or top_logprobs is not None:
+         raise NotImplementedError(
+             "We have not implemented logprobs or top_logprobs for Google yet."
+         )
+
+     system, input_messages = _parse_chat_messages(messages, tools_provided=bool(tools))
+
+     try:
+         async with async_timeout_ctx(timeout):
+             thinking_cfg = None
+             if reasoning_effort:
+                 thinking_cfg = types.ThinkingConfig(
+                     include_thoughts=True,
+                     thinking_budget=reasoning_budget(max_new_tokens, reasoning_effort),
+                 )
+
+             stream = await client.models.generate_content_stream(  # type: ignore
+                 model=model_name,
+                 contents=input_messages,  # type: ignore
+                 config=types.GenerateContentConfig(
+                     temperature=temperature,
+                     thinking_config=thinking_cfg,
+                     max_output_tokens=max_new_tokens,
+                     system_instruction=system,
+                     tools=_parse_tools(tools) if tools else None,
+                     tool_config=(
+                         types.ToolConfig(function_calling_config=_parse_tool_choice(tool_choice))
+                         if tool_choice is not None
+                         else None
+                     ),
+                 ),
+             )
+
+             accumulated_text = ""
+             accumulated_tool_calls: list[ToolCall] = []
+             finish_reason: str | None = None
+             usage = UsageMetrics()
+
+             async for chunk in stream:
+                 candidate = chunk.candidates[0] if chunk.candidates else None
+                 if candidate and candidate.content and candidate.content.parts:
+                     for part in candidate.content.parts:
+                         if part.text is not None and not part.thought:
+                             accumulated_text += part.text or ""
+                         elif part.function_call is not None:
+                             fc = part.function_call
+                             args = coerce_tool_args(getattr(fc, "args", {}))
+                             accumulated_tool_calls.append(
+                                 ToolCall(
+                                     id=getattr(fc, "id", None)
+                                     or f"{getattr(fc, 'name', 'tool')}_call",
+                                     function=getattr(fc, "name", "unknown"),
+                                     arguments=args,
+                                     type="function",
+                                 )
+                             )
+
+                 if candidate and candidate.finish_reason is not None:
+                     if candidate.finish_reason == types.FinishReason.STOP:
+                         finish_reason = "stop"
+                     elif candidate.finish_reason == types.FinishReason.MAX_TOKENS:
+                         finish_reason = "length"
+                     else:
+                         finish_reason = "error"
+
+                 # Check for usage metadata in the chunk
+                 if usage_metadata := chunk.usage_metadata:
+                     if usage_metadata.prompt_token_count is not None:
+                         usage["input"] = int(usage_metadata.prompt_token_count)
+                     if usage_metadata.candidates_token_count is not None:
+                         usage["output"] = int(usage_metadata.candidates_token_count)
+
+                 if streaming_callback is not None:
+                     await streaming_callback(
+                         LLMOutput(
+                             model=model_name,
+                             completions=[LLMCompletion(text=accumulated_text)],
+                         )
+                     )
+
+         return LLMOutput(
+             model=model_name,
+             completions=[
+                 LLMCompletion(
+                     text=accumulated_text,
+                     tool_calls=(accumulated_tool_calls or None),
+                     finish_reason=finish_reason,
+                 )
+             ],
+             usage=usage,
+         )
+     except errors.APIError as e:
+         if e2 := _convert_google_error(e):
+             raise e2 from e
+         else:
+             raise
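
The streaming variant invokes the callback once per chunk with the accumulated text so far, not a delta. A sketch under the same assumptions as the non-streaming example above:

    import asyncio

    from docent._llm_util.data_models.llm_output import LLMOutput

    async def on_update(partial: LLMOutput) -> None:
        if partial.first:
            print(partial.first.text)  # accumulated text so far

    async def demo_stream() -> None:
        client = get_google_client_async()
        messages: list[ChatMessage] = [...]  # build docent ChatMessage objects here
        final = await get_google_chat_completion_streaming_async(
            client=client,
            streaming_callback=on_update,
            messages=messages,
            model_name="gemini-2.0-flash",  # illustrative
            max_new_tokens=256,
            timeout=30.0,
        )
        print(final.first.finish_reason if final.first else None)

    asyncio.run(demo_stream())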
+
+
+ def _parse_chat_messages(
+     messages: list[ChatMessage],
+     *,
+     tools_provided: bool = False,
+ ) -> tuple[str | None, list[types.Content]]:
+     result: list[types.Content] = []
+     system_prompt: str | None = None
+
+     for message in messages:
+         if message.role == "user":
+             parts = _parse_message_content(message.content)
+             if parts:  # Avoid sending empty text parts
+                 result.append(
+                     types.Content(
+                         role="user",
+                         parts=parts,
+                     )
+                 )
+         elif message.role == "assistant":
+             parts: list[types.Part] = _parse_message_content(message.content)
+             # If assistant previously made tool calls, include them so the model has full context
+             for tool_call in getattr(message, "tool_calls", []) or []:
+                 try:
+                     parts.append(
+                         types.Part.from_function_call(
+                             name=tool_call.function,
+                             args=tool_call.arguments,  # type: ignore[arg-type]
+                             id=tool_call.id,  # type: ignore[call-arg]
+                         )
+                     )
+                 except Exception:
+                     # Fallback without id if the SDK signature differs
+                     parts.append(
+                         types.Part.from_function_call(
+                             name=tool_call.function,
+                             args=tool_call.arguments,  # type: ignore[arg-type]
+                         )
+                     )
+             if parts:  # If only tool calls with no text, we still include function_call parts
+                 result.append(types.Content(role="model", parts=parts))
+             elif getattr(message, "tool_calls", []):
+                 # Include just the tool calls if present
+                 result.append(types.Content(role="model", parts=parts))
+         elif message.role == "tool":
+             # Represent tool result as a function_response part (Gemini tool execution result)
+             if not tools_provided:
+                 # If no tools configured, pass through as plain text
+                 parts = _parse_message_content(message.content)
+                 if parts:
+                     result.append(types.Content(role="user", parts=parts))
+             else:
+                 tool_name = getattr(message, "function", None) or "unknown_tool"
+                 tool_id = getattr(message, "tool_call_id", None)
+                 # Try to parse tool content as JSON if it looks like JSON; otherwise wrap as text
+                 tool_text = message.text or ""
+                 response_obj: dict[str, Any]
+                 try:
+                     import json as _json
+
+                     parsed = _json.loads(tool_text)
+                     if isinstance(parsed, dict):  # type: ignore[redundant-cast]
+                         response_obj = cast(dict[str, Any], parsed)
+                     else:
+                         response_obj = {"result": parsed}
+                 except Exception:
+                     response_obj = {"result": tool_text}
+
+                 part = _make_function_response_part(name=tool_name, response=response_obj, id=tool_id)  # type: ignore[arg-type]
+                 result.append(types.Content(role="user", parts=[part]))
+         elif message.role == "system":
+             system_prompt = message.text
+         else:
+             raise ValueError(f"Unknown message role: {message.role}")
+
+     return system_prompt, result
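
The net role mapping: user content becomes Content(role="user"), assistant content becomes Content(role="model") with any function_call parts attached, a system message is returned separately for system_instruction, and tool results travel as function_response parts on a user-role Content. A sketch with hypothetical message objects:

    system, contents = _parse_chat_messages(
        [system_msg, user_msg, assistant_msg, tool_msg],  # hypothetical docent messages
        tools_provided=True,
    )
    # system == system_msg.text
    # [c.role for c in contents] == ["user", "model", "user"]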
+
+
+ def _parse_message_content(content: str | list[Content]) -> list[types.Part]:
+     if isinstance(content, str):
+         text = content.strip()
+         return [types.Part.from_text(text=text)] if text else []
+     else:
+         result: list[types.Part] = []
+         for sub_content in content:
+             if sub_content.type == "text":
+                 txt = (sub_content.text or "").strip()
+                 if txt:
+                     result.append(types.Part.from_text(text=txt))
+             else:
+                 raise ValueError(f"Unsupported content type: {sub_content.type}")
+         return result
+
+
+ def _parse_google_completion(message: types.GenerateContentResponse, model: str) -> LLMOutput:
+     if not message.candidates:
+         return LLMOutput(
+             model=model,
+             completions=[],
+             errors=[NoResponseException()],
+         )
+
+     candidate = message.candidates[0]
+
+     if candidate.finish_reason == types.FinishReason.STOP:
+         finish_reason = "stop"
+     elif candidate.finish_reason == types.FinishReason.MAX_TOKENS:
+         finish_reason = "length"
+     else:
+         finish_reason = "error"
+
+     text = ""
+     tool_calls: list[ToolCall] = []
+     content_parts = candidate.content.parts if candidate.content else []
+     content_parts = content_parts or []
+     for part in content_parts:
+         if part.text is not None and not part.thought:
+             text += part.text
+         elif part.thought:
+             logger.warning("Google returned thinking block; we should support this soon.")
+         elif getattr(part, "function_call", None) is not None:
+             fc = part.function_call
+             # Attempt to parse arguments as a dictionary
+             args = coerce_tool_args(getattr(fc, "args", {}))
+             tool_calls.append(
+                 ToolCall(
+                     id=getattr(fc, "id", None) or f"{getattr(fc, 'name', 'tool')}_call",
+                     function=getattr(fc, "name", "unknown"),
+                     arguments=args,
+                     type="function",
+                 )
+             )
+         else:
+             raise ValueError(f"Unknown content part: {part}")
+
+     # Extract usage metrics from the response
+     usage = UsageMetrics()
+     if usage_metadata := message.usage_metadata:
+         if usage_metadata.prompt_token_count is not None:
+             usage["input"] = int(usage_metadata.prompt_token_count)
+         if usage_metadata.candidates_token_count is not None:
+             usage["output"] = int(usage_metadata.candidates_token_count)
+
+     return LLMOutput(
+         model=model,
+         completions=[
+             LLMCompletion(
+                 text=text,
+                 finish_reason=("tool_calls" if tool_calls else finish_reason),
+                 tool_calls=(tool_calls or None),
+             )
+         ],
+         usage=usage,
+     )
+
+
+ def _parse_tools(tools: list[ToolInfo]) -> list[types.Tool]:
+     # Gemini expects a list of Tool objects, each with one or more FunctionDeclarations
+     fds: list[types.FunctionDeclaration] = []
+     for tool in tools:
+         fds.append(
+             types.FunctionDeclaration(
+                 name=tool.name,
+                 description=tool.description,
+                 parameters=_convert_toolparams_to_schema(tool.parameters),
+             )
+         )
+     # Group all function declarations into a single Tool for simplicity
+     return [types.Tool(function_declarations=fds)]
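
Concretely, every declaration is grouped into one Tool. For a single invented weather tool, the return value would look roughly like this (Schema shapes follow the converter later in this file):

    [
        types.Tool(
            function_declarations=[
                types.FunctionDeclaration(
                    name="get_weather",
                    description="Look up current weather",
                    parameters=types.Schema(
                        type=types.Type.OBJECT,
                        properties={"city": types.Schema(type=types.Type.STRING)},
                        required=["city"],
                    ),
                )
            ]
        )
    ]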
+
+
+ def _parse_tool_choice(tool_choice: Literal["auto", "required"] | None):
+     if tool_choice is None:
+         return None
+     # Map our values to SDK enum; if unavailable, return None so default behavior applies
+     try:
+         if tool_choice == "auto":
+             return types.FunctionCallingConfig(mode=types.FunctionCallingConfigMode.AUTO)
+         elif tool_choice == "required":
+             return types.FunctionCallingConfig(mode=types.FunctionCallingConfigMode.ANY)
+     except Exception:
+         return None
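
Note that "required" maps to Gemini's ANY mode, which forces the model to call some function. The resulting config, sketched with the same SDK types used above:

    cfg = types.ToolConfig(
        function_calling_config=types.FunctionCallingConfig(
            mode=types.FunctionCallingConfigMode.ANY  # "required" -> ANY
        )
    )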
+
+
+ def _convert_toolparams_to_schema(params: Any) -> types.Schema:
+     properties: dict[str, types.Schema] = {}
+     params_props: dict[str, Any] = getattr(params, "properties", {}) or {}
+     for name, param in params_props.items():
+         prop_schema = _convert_json_schema_to_gemini_schema(
+             (getattr(param, "input_schema", {}) or {})
+         )
+         desc: Any = getattr(param, "description", None)
+         if desc and prop_schema.description is None:
+             prop_schema.description = desc
+         properties[str(name)] = prop_schema
+
+     required_names: list[str] | None = None
+     required_raw: Any = getattr(params, "required", None)
+     if isinstance(required_raw, list):
+         required_list: list[Any] = cast(list[Any], required_raw)
+         required_names = [str(item) for item in required_list]
+
+     return types.Schema(
+         type=types.Type.OBJECT,
+         properties=properties or None,
+         required=required_names,
+     )
+
+
+ def _convert_json_schema_to_gemini_schema(js: dict[str, Any]) -> types.Schema:
+     type_get: Any = js.get("type")
+     type_name: str
+     if isinstance(type_get, str):
+         type_name = type_get.lower()
+     elif isinstance(type_get, list):
+         # Convert list to list[str] then take first
+         type_list: list[str] = [str(v) for v in cast(list[Any], type_get)]
+         type_name = type_list[0].lower() if type_list else ""
+     elif type_get is None:
+         type_name = ""
+     else:
+         type_name = str(type_get).lower()
+     if type_name == "string":
+         t: types.Type | None = types.Type.STRING
+     elif type_name == "number":
+         t = types.Type.NUMBER
+     elif type_name == "integer":
+         t = types.Type.INTEGER
+     elif type_name == "boolean":
+         t = types.Type.BOOLEAN
+     elif type_name == "array":
+         t = types.Type.ARRAY
+     elif type_name == "object":
+         t = types.Type.OBJECT
+     elif type_name == "null":
+         t = types.Type.NULL
+     else:
+         t = None
+     description = js.get("description")
+     enum_vals_any: Any = js.get("enum")
+     enum_vals: list[str] | None = None
+     if isinstance(enum_vals_any, list):
+         enum_vals = [str(v) for v in cast(list[Any], enum_vals_any)] or None
+
+     props_in_raw_any: Any = js.get("properties") or {}
+     props_in_raw: dict[str, Any] = (
+         cast(dict[str, Any], props_in_raw_any) if isinstance(props_in_raw_any, dict) else {}
+     )
+     props_out: dict[str, types.Schema] | None = None
+     if props_in_raw:
+         tmp_props: dict[str, types.Schema] = {}
+         for key, val in props_in_raw.items():
+             if isinstance(val, dict):
+                 tmp_props[str(key)] = _convert_json_schema_to_gemini_schema(
+                     cast(dict[str, Any], val)
+                 )
+         props_out = tmp_props if tmp_props else None
+
+     required_out: list[str] | None = None
+     required_raw_js: Any = js.get("required")
+     if isinstance(required_raw_js, list):
+         tmp_required_any: list[Any] = cast(list[Any], required_raw_js)
+         tmp_required: list[str] = [str(item) for item in tmp_required_any]
+         required_out = tmp_required or None
+
+     items_in_any: Any = js.get("items")
+     items_out: types.Schema | None = None
+     if isinstance(items_in_any, dict):
+         items_out = _convert_json_schema_to_gemini_schema(cast(dict[str, Any], items_in_any))
+
+     return types.Schema(
+         type=t,
+         description=description,
+         enum=enum_vals,
+         properties=props_out,
+         required=required_out,
+         items=items_out,
+     )
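
For example, a nested JSON Schema converts recursively; the expected results below follow directly from the branches above:

    js = {
        "type": "object",
        "properties": {
            "tags": {"type": "array", "items": {"type": "string"}},
            "level": {"type": "string", "enum": ["low", "high"]},
        },
        "required": ["level"],
    }
    schema = _convert_json_schema_to_gemini_schema(js)
    # schema.type == types.Type.OBJECT
    # schema.properties["tags"].items.type == types.Type.STRING
    # schema.required == ["level"]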
+
+
+ def _make_function_response_part(
+     *, name: str, response: dict[str, object], id: str | None
+ ) -> types.Part:
+     try:
+         return types.Part.from_function_response(name=name, response=response, id=id)  # type: ignore[call-arg]
+     except Exception:
+         return types.Part.from_function_response(name=name, response=response)
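
Closing the tool-result path: a parsed tool message is forwarded to Gemini roughly like this (the id keyword is tried first and dropped if the installed SDK rejects it; the values are invented):

    part = _make_function_response_part(
        name="get_weather",
        response={"temperature_c": 21},
        id="get_weather_call",
    )
    content = types.Content(role="user", parts=[part])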