docent-python 0.1.20a0__py3-none-any.whl → 0.1.21a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +320 -0
- docent/_llm_util/data_models/simple_svc.py +79 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/model_registry.py +126 -0
- docent/_llm_util/prod_llms.py +454 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/transcript.py +2 -0
- docent/judges/__init__.py +21 -0
- docent/judges/impl.py +222 -0
- docent/judges/types.py +240 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +84 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +95 -0
- docent/judges/util/voting.py +84 -0
- docent/trace.py +1 -1
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.21a0.dist-info}/METADATA +6 -1
- docent_python-0.1.21a0.dist-info/RECORD +58 -0
- docent_python-0.1.20a0.dist-info/RECORD +0 -34
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.21a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.21a0.dist-info}/licenses/LICENSE.md +0 -0
docent/_llm_util/providers/openrouter.py
ADDED
@@ -0,0 +1,375 @@

"""OpenRouter provider implementation using aiohttp library."""

import json
import os
from typing import Any, Literal

import aiohttp
import backoff
from backoff.types import Details

from docent._llm_util.data_models.exceptions import (
    CompletionTooLongException,
    ContextWindowException,
    NoResponseException,
    RateLimitException,
)
from docent._llm_util.data_models.llm_output import (
    AsyncSingleLLMOutputStreamingCallback,
    LLMCompletion,
    LLMOutput,
    UsageMetrics,
)
from docent._log_util import get_logger
from docent.data_models.chat import ChatMessage, Content, ToolCall, ToolInfo

logger = get_logger(__name__)

OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"


class OpenRouterClient:
    """Async client for OpenRouter API using aiohttp."""

    def __init__(self, api_key: str | None = None):
        self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
        self.base_url = OPENROUTER_API_BASE

    def _get_headers(self) -> dict[str, str]:
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    async def chat_completions_create(
        self,
        model: str,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str | None = None,
        max_tokens: int = 32,
        temperature: float = 1.0,
        timeout: float = 30.0,
    ) -> dict[str, Any]:
        """Make an async chat completion request."""
        url = f"{self.base_url}/chat/completions"

        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        if tools:
            payload["tools"] = tools
        if tool_choice:
            payload["tool_choice"] = tool_choice

        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                json=payload,
                headers=self._get_headers(),
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as response:
                if response.status != 200:
                    try:
                        error_data: dict[str, Any] = await response.json()
                        error_msg: Any = error_data.get("error", {}).get(
                            "message", await response.text()
                        )
                    except Exception:
                        error_msg = await response.text()
                    if response.status == 429:
                        raise RateLimitException(f"OpenRouter rate limit: {error_msg}")
                    elif response.status == 400 and "context" in str(error_msg).lower():
                        raise ContextWindowException()
                    else:
                        raise Exception(f"OpenRouter API error ({response.status}): {error_msg}")

                return await response.json()


def get_openrouter_client_async(api_key: str | None = None) -> OpenRouterClient:
    return OpenRouterClient(api_key=api_key)


def _print_backoff_message(e: Details):
    logger.warning(
        f"OpenRouter backing off for {e['wait']:.2f}s due to {e['exception'].__class__.__name__}"  # type: ignore
    )


def _is_retryable_error(e: BaseException) -> bool:
    if isinstance(e, RateLimitException):
        return True
    if isinstance(e, ContextWindowException):
        return False
    if isinstance(e, aiohttp.ClientError):
        return True
    return False


def _parse_message_content(
    content: str | list[Content],
) -> str | list[dict[str, str]]:
    if isinstance(content, str):
        return content
    else:
        result: list[dict[str, str]] = []
        for sub_content in content:
            if sub_content.type == "text":
                result.append({"type": "text", "text": sub_content.text})
            else:
                raise ValueError(f"Unsupported content type: {sub_content.type}")
        return result


def parse_chat_messages(messages: list[ChatMessage]) -> list[dict[str, Any]]:
    """Convert ChatMessage list to OpenRouter format."""
    result: list[dict[str, Any]] = []

    for message in messages:
        if message.role == "user":
            result.append(
                {
                    "role": "user",
                    "content": _parse_message_content(message.content),
                }
            )
        elif message.role == "assistant":
            msg: dict[str, Any] = {
                "role": "assistant",
                "content": _parse_message_content(message.content),
            }
            if message.tool_calls:
                msg["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.function,
                            "arguments": json.dumps(tc.arguments),
                        },
                    }
                    for tc in message.tool_calls
                ]
            result.append(msg)
        elif message.role == "tool":
            result.append(
                {
                    "role": "tool",
                    "content": _parse_message_content(message.content),
                    "tool_call_id": str(message.tool_call_id),
                }
            )
        elif message.role == "system":
            result.append(
                {
                    "role": "system",
                    "content": _parse_message_content(message.content),
                }
            )

    return result


def parse_tools(tools: list[ToolInfo]) -> list[dict[str, Any]]:
    """Convert ToolInfo objects to OpenRouter format."""
    result: list[dict[str, Any]] = []

    for tool in tools:
        result.append(
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters.model_dump(exclude_none=True),
                },
            }
        )

    return result


def _parse_openrouter_tool_call(tc: dict[str, Any]) -> ToolCall:
    """Parse tool call from OpenRouter response."""
    if tc.get("type") != "function":
        return ToolCall(
            id=tc.get("id", "unknown"),
            function="unknown",
            arguments={},
            parse_error=f"Unsupported tool call type: {tc.get('type')}",
            type=None,
        )

    function_data = tc.get("function", {})
    arguments: dict[str, Any] = {}
    try:
        arguments = json.loads(function_data.get("arguments", "{}"))
        parse_error = None
    except Exception as e:
        arguments = {"__parse_error_raw_args": function_data.get("arguments", "")}
        parse_error = f"Couldn't parse tool call arguments as JSON: {e}"

    return ToolCall(
        id=tc.get("id", "unknown"),
        function=function_data.get("name", "unknown"),
        arguments=arguments,
        parse_error=parse_error,
        type="function",
    )


def parse_openrouter_completion(response: dict[str, Any], model: str) -> LLMOutput:
    """Parse OpenRouter completion response."""
    choices = response.get("choices", [])
    if not choices:
        return LLMOutput(
            model=model,
            completions=[],
            errors=[NoResponseException()],
        )

    usage_data = response.get("usage", {})
    usage = UsageMetrics(
        input=usage_data.get("prompt_tokens", 0),
        output=usage_data.get("completion_tokens", 0),
    )

    completions: list[LLMCompletion] = []
    for choice in choices:
        message = choice.get("message", {})
        tool_calls_data = message.get("tool_calls")

        completions.append(
            LLMCompletion(
                text=message.get("content"),
                finish_reason=choice.get("finish_reason"),
                tool_calls=(
                    [_parse_openrouter_tool_call(tc) for tc in tool_calls_data]
                    if tool_calls_data
                    else None
                ),
            )
        )

    return LLMOutput(
        model=response.get("model", model),
        completions=completions,
        usage=usage,
    )


@backoff.on_exception(
    backoff.expo,
    exception=(Exception,),
    giveup=lambda e: not _is_retryable_error(e),
    max_tries=5,
    factor=3.0,
    on_backoff=_print_backoff_message,
)
async def get_openrouter_chat_completion_async(
    client: OpenRouterClient,
    messages: list[ChatMessage],
    model_name: str,
    tools: list[ToolInfo] | None = None,
    tool_choice: Literal["auto", "required"] | None = None,
    max_new_tokens: int = 32,
    temperature: float = 1.0,
    reasoning_effort: Literal["low", "medium", "high"] | None = None,
    logprobs: bool = False,
    top_logprobs: int | None = None,
    timeout: float = 30.0,
) -> LLMOutput:
    """Get completion from OpenRouter."""
    if logprobs or top_logprobs is not None:
        raise NotImplementedError(
            "We have not implemented logprobs or top_logprobs for OpenRouter yet."
        )

    if reasoning_effort is not None:
        logger.warning("OpenRouter does not support reasoning_effort parameter, ignoring.")

    input_messages = parse_chat_messages(messages)
    input_tools = parse_tools(tools) if tools else None

    response = await client.chat_completions_create(
        model=model_name,
        messages=input_messages,
        tools=input_tools,
        tool_choice=tool_choice,
        max_tokens=max_new_tokens,
        temperature=temperature,
        timeout=timeout,
    )

    output = parse_openrouter_completion(response, model_name)

    if output.first and output.first.finish_reason == "length" and output.first.no_text:
        raise CompletionTooLongException(
            "Completion empty due to truncation. Consider increasing max_new_tokens."
        )

    return output


@backoff.on_exception(
    backoff.expo,
    exception=(Exception,),
    giveup=lambda e: not _is_retryable_error(e),
    max_tries=5,
    factor=3.0,
    on_backoff=_print_backoff_message,
)
async def get_openrouter_chat_completion_streaming_async(
    client: OpenRouterClient,
    streaming_callback: AsyncSingleLLMOutputStreamingCallback | None,
    messages: list[ChatMessage],
    model_name: str,
    tools: list[ToolInfo] | None = None,
    tool_choice: Literal["auto", "required"] | None = None,
    max_new_tokens: int = 32,
    temperature: float = 1.0,
    reasoning_effort: Literal["low", "medium", "high"] | None = None,
    logprobs: bool = False,
    top_logprobs: int | None = None,
    timeout: float = 30.0,
) -> LLMOutput:
    """Get streaming completion from OpenRouter (falls back to non-streaming)."""
    logger.warning("Streaming not yet implemented for OpenRouter, using non-streaming.")

    return await get_openrouter_chat_completion_async(
        client=client,
        messages=messages,
        model_name=model_name,
        tools=tools,
        tool_choice=tool_choice,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        reasoning_effort=reasoning_effort,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        timeout=timeout,
    )


async def is_openrouter_api_key_valid(api_key: str) -> bool:
    """Test whether an OpenRouter API key is valid."""
    client = OpenRouterClient(api_key=api_key)

    try:
        # Make a minimal request to test the key
        await client.chat_completions_create(
            model="openai/gpt-5-nano",
            messages=[{"role": "user", "content": "hi"}],
            max_tokens=1,
            timeout=10.0,
        )
        return True
    except Exception as e:
        if "authentication" in str(e).lower() or "authorization" in str(e).lower():
            return False
        raise
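For orientation, here is a minimal, hypothetical calling sketch against the low-level client added above; it is not part of the package. It assumes OPENROUTER_API_KEY is set in the environment and that the openai/gpt-5-nano route (the one used by the key-validation helper) is reachable. The higher-level get_openrouter_chat_completion_async wraps the same request with retries and ChatMessage/ToolInfo conversion.

```python
import asyncio

from docent._llm_util.providers.openrouter import (
    OpenRouterClient,
    is_openrouter_api_key_valid,
)


async def main() -> None:
    # With api_key=None the client falls back to the OPENROUTER_API_KEY env var.
    client = OpenRouterClient()

    # Low-level request: messages are plain OpenAI-style dicts at this layer.
    response = await client.chat_completions_create(
        model="openai/gpt-5-nano",
        messages=[{"role": "user", "content": "Say hello in one word."}],
        max_tokens=64,
    )
    print(response["choices"][0]["message"]["content"])

    # Key-validation helper from the same module.
    print(await is_openrouter_api_key_valid(client.api_key or ""))


asyncio.run(main())
```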
docent/_llm_util/providers/preference_types.py
ADDED
@@ -0,0 +1,104 @@

"""Provides preferences of which LLM models to use for different Docent functions."""

from functools import cached_property
from typing import Literal

from pydantic import BaseModel

from docent._llm_util.model_registry import get_context_window
from docent._log_util import get_logger

logger = get_logger(__name__)


class ModelOption(BaseModel):
    """Configuration for a specific model from a provider. Not to be confused with ModelInfo.

    Attributes:
        provider: The name of the LLM provider (e.g., "openai", "anthropic").
        model_name: The specific model to use from the provider.
        reasoning_effort: Optional indication of computational effort to use.
    """

    provider: str
    model_name: str
    reasoning_effort: Literal["low", "medium", "high"] | None = None


class ModelOptionWithContext(BaseModel):
    """Enhanced model option that includes context window information for frontend use.
    Not to be confused with ModelInfo or ModelOption.

    Attributes:
        provider: The name of the LLM provider (e.g., "openai", "anthropic").
        model_name: The specific model to use from the provider.
        reasoning_effort: Optional indication of computational effort to use.
        context_window: The context window size in tokens.
        uses_byok: Whether this model would use the user's own API key.
    """

    provider: str
    model_name: str
    reasoning_effort: Literal["low", "medium", "high"] | None = None
    context_window: int
    uses_byok: bool

    @classmethod
    def from_model_option(
        cls, model_option: ModelOption, uses_byok: bool = False
    ) -> "ModelOptionWithContext":
        """Create a ModelOptionWithContext from a ModelOption.

        Args:
            model_option: The base model option
            uses_byok: Whether this model requires bring-your-own-key

        Returns:
            ModelOptionWithContext with context window looked up from global mapping
        """
        context_window = get_context_window(model_option.model_name)

        return cls(
            provider=model_option.provider,
            model_name=model_option.model_name,
            reasoning_effort=model_option.reasoning_effort,
            context_window=context_window,
            uses_byok=uses_byok,
        )


def merge_models_with_byok(
    defaults: list[ModelOption],
    byok: list[ModelOption],
    api_keys: dict[str, str] | None,
) -> list[ModelOptionWithContext]:
    user_keys = api_keys or {}

    merged: list[ModelOption] = list(defaults)
    if user_keys:
        merged.extend([m for m in byok if m.provider in user_keys])

    return [ModelOptionWithContext.from_model_option(m, m.provider in user_keys) for m in merged]


class PublicProviderPreferences(BaseModel):
    @cached_property
    def default_judge_models(self) -> list[ModelOption]:
        """Judge models that any user can access without providing their own API key"""

        return [
            ModelOption(provider="openai", model_name="gpt-5", reasoning_effort="medium"),
            ModelOption(provider="openai", model_name="gpt-5", reasoning_effort="low"),
            ModelOption(provider="openai", model_name="gpt-5", reasoning_effort="high"),
            ModelOption(provider="openai", model_name="gpt-5-mini", reasoning_effort="low"),
            ModelOption(provider="openai", model_name="gpt-5-mini", reasoning_effort="medium"),
            ModelOption(provider="openai", model_name="gpt-5-mini", reasoning_effort="high"),
            ModelOption(
                provider="anthropic",
                model_name="claude-sonnet-4-20250514",
                reasoning_effort="medium",
            ),
        ]


PUBLIC_PROVIDER_PREFERENCES = PublicProviderPreferences()
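A short sketch of how the BYOK merge above is presumably meant to be used. The inputs are hypothetical, not values shipped with the package, and the sketch assumes the illustrative Gemini model name is known to the registry consulted by get_context_window.

```python
from docent._llm_util.providers.preference_types import (
    PUBLIC_PROVIDER_PREFERENCES,
    ModelOption,
    merge_models_with_byok,
)

# Hypothetical BYOK-only option and user-supplied key. Only byok entries whose
# provider appears in api_keys are merged in, and uses_byok is set per model.
byok_only = [ModelOption(provider="google", model_name="gemini-2.5-pro")]

options = merge_models_with_byok(
    defaults=PUBLIC_PROVIDER_PREFERENCES.default_judge_models,
    byok=byok_only,
    api_keys={"google": "user-supplied-key"},
)
for opt in options:
    print(opt.provider, opt.model_name, opt.context_window, opt.uses_byok)
```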
docent/_llm_util/providers/provider_registry.py
ADDED
@@ -0,0 +1,164 @@

"""Registry for LLM providers with their configurations."""

from __future__ import annotations

from typing import Any, Callable, Literal, Protocol, TypedDict

from docent._llm_util.data_models.llm_output import (
    AsyncSingleLLMOutputStreamingCallback,
    LLMOutput,
)
from docent._llm_util.providers import anthropic, google, openai, openrouter
from docent._llm_util.providers.anthropic import (
    get_anthropic_chat_completion_async,
    get_anthropic_chat_completion_streaming_async,
)
from docent._llm_util.providers.google import (
    get_google_chat_completion_async,
    get_google_chat_completion_streaming_async,
)
from docent._llm_util.providers.openai import (
    get_openai_chat_completion_async,
    get_openai_chat_completion_streaming_async,
)
from docent._llm_util.providers.openrouter import (
    get_openrouter_chat_completion_async,
    get_openrouter_chat_completion_streaming_async,
)
from docent.data_models.chat import ChatMessage, ToolInfo


class SingleOutputGetter(Protocol):
    """Protocol for getting non-streaming output from an LLM.

    Defines the interface for async functions that retrieve a single
    non-streaming response from an LLM provider.
    """

    async def __call__(
        self,
        client: Any,
        messages: list[ChatMessage],
        model_name: str,
        *,
        tools: list[ToolInfo] | None,
        tool_choice: Literal["auto", "required"] | None,
        max_new_tokens: int,
        temperature: float,
        reasoning_effort: Literal["low", "medium", "high"] | None,
        logprobs: bool,
        top_logprobs: int | None,
        timeout: float,
    ) -> LLMOutput:
        """Get a single completion from an LLM.

        Args:
            client: The provider-specific client instance.
            messages: The list of messages in the conversation.
            model_name: The name of the model to use.
            tools: Optional list of tools available to the model.
            tool_choice: Optional specification for tool usage.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Controls randomness in output generation.
            reasoning_effort: Optional control for model reasoning depth.
            logprobs: Whether to return log probabilities.
            top_logprobs: Number of most likely tokens to return probabilities for.
            timeout: Maximum time to wait for a response in seconds.

        Returns:
            LLMOutput: The model's response.
        """
        ...


class SingleStreamingOutputGetter(Protocol):
    """Protocol for getting streaming output from an LLM.

    Defines the interface for async functions that retrieve streaming
    responses from an LLM provider.
    """

    async def __call__(
        self,
        client: Any,
        streaming_callback: AsyncSingleLLMOutputStreamingCallback | None,
        messages: list[ChatMessage],
        model_name: str,
        *,
        tools: list[ToolInfo] | None,
        tool_choice: Literal["auto", "required"] | None,
        max_new_tokens: int,
        temperature: float,
        reasoning_effort: Literal["low", "medium", "high"] | None,
        logprobs: bool,
        top_logprobs: int | None,
        timeout: float,
    ) -> LLMOutput:
        """Get a streaming completion from an LLM.

        Args:
            client: The provider-specific client instance.
            streaming_callback: Optional callback for processing streaming chunks.
            messages: The list of messages in the conversation.
            model_name: The name of the model to use.
            tools: Optional list of tools available to the model.
            tool_choice: Optional specification for tool usage.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Controls randomness in output generation.
            reasoning_effort: Optional control for model reasoning depth.
            logprobs: Whether to return log probabilities.
            top_logprobs: Number of most likely tokens to return probabilities for.
            timeout: Maximum time to wait for a response in seconds.

        Returns:
            LLMOutput: The complete model response after streaming finishes.
        """
        ...


class ProviderConfig(TypedDict):
    """Configuration for an LLM provider.

    Contains the necessary functions to create clients and interact with
    a specific LLM provider.

    Attributes:
        async_client_getter: Function to get an async client for the provider.
        single_output_getter: Function to get a non-streaming completion.
        single_streaming_output_getter: Function to get a streaming completion.
    """

    async_client_getter: Callable[[str | None], Any]
    single_output_getter: SingleOutputGetter
    single_streaming_output_getter: SingleStreamingOutputGetter


# Registry of supported LLM providers with their respective configurations
PROVIDERS: dict[str, ProviderConfig] = {
    "anthropic": ProviderConfig(
        async_client_getter=anthropic.get_anthropic_client_async,
        single_output_getter=get_anthropic_chat_completion_async,
        single_streaming_output_getter=get_anthropic_chat_completion_streaming_async,
    ),
    "google": ProviderConfig(
        async_client_getter=google.get_google_client_async,
        single_output_getter=get_google_chat_completion_async,
        single_streaming_output_getter=get_google_chat_completion_streaming_async,
    ),
    "openai": ProviderConfig(
        async_client_getter=openai.get_openai_client_async,
        single_output_getter=get_openai_chat_completion_async,
        single_streaming_output_getter=get_openai_chat_completion_streaming_async,
    ),
    "azure_openai": ProviderConfig(
        async_client_getter=openai.get_azure_openai_client_async,
        single_output_getter=get_openai_chat_completion_async,
        single_streaming_output_getter=get_openai_chat_completion_streaming_async,
    ),
    "openrouter": ProviderConfig(
        async_client_getter=openrouter.get_openrouter_client_async,
        single_output_getter=get_openrouter_chat_completion_async,
        single_streaming_output_getter=get_openrouter_chat_completion_streaming_async,
    ),
}
"""Registry of supported LLM providers with their respective configurations."""
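The registry gives every provider the same client-construction and completion interface. A hypothetical dispatch helper (not taken from the package) might look like the following, relying only on the ProviderConfig keys and the SingleOutputGetter signature defined above:

```python
from typing import Any

from docent._llm_util.data_models.llm_output import LLMOutput
from docent._llm_util.providers.provider_registry import PROVIDERS
from docent.data_models.chat import ChatMessage


async def complete_with_provider(
    provider: str,
    model_name: str,
    messages: list[ChatMessage],
    api_key: str | None = None,
) -> LLMOutput:
    # Look up the provider's config, build its async client, and call the
    # uniform non-streaming entry point described by SingleOutputGetter.
    config = PROVIDERS[provider]
    client: Any = config["async_client_getter"](api_key)
    return await config["single_output_getter"](
        client,
        messages,
        model_name,
        tools=None,
        tool_choice=None,
        max_new_tokens=1024,
        temperature=1.0,
        reasoning_effort=None,
        logprobs=False,
        top_logprobs=None,
        timeout=60.0,
    )
```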
docent/data_models/transcript.py
CHANGED
@@ -42,6 +42,8 @@ Important notes:
 - Citations are self-contained. Do NOT label them as citation or evidence. Just insert the citation by itself at the appropriate place in the text.
 - Citations must come immediately after the part of a claim that they support. This may be in the middle of a sentence.
 - Each pair of brackets must contain only one citation. To cite multiple blocks, use multiple pairs of brackets, like [T0B0] [T0B1].
+- Outside of citations, do not refer to transcript numbers or block numbers.
+- Outside of citations, avoid quoting or paraphrasing the transcript.
 """

 BLOCK_CITE_INSTRUCTION = """Each transcript and each block has a unique index. Cite the relevant indices in brackets when relevant, like [T<idx>B<idx>]. Use multiple tags to cite multiple blocks, like [T<idx1>B<idx1>][T<idx2>B<idx2>]. Remember to cite specific blocks and NOT action units."""
docent/judges/__init__.py
ADDED
@@ -0,0 +1,21 @@

from docent.judges.impl import BaseJudge, MajorityVotingJudge, MultiReflectionJudge
from docent.judges.types import (
    JudgeResult,
    JudgeResultCompletionCallback,
    JudgeResultWithCitations,
    ResultType,
    Rubric,
)

__all__ = [
    # Judges
    "MajorityVotingJudge",
    "MultiReflectionJudge",
    "BaseJudge",
    # Types
    "Rubric",
    "JudgeResult",
    "JudgeResultWithCitations",
    "JudgeResultCompletionCallback",
    "ResultType",
]