docent-python 0.1.14a0__py3-none-any.whl → 0.1.28a0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of docent-python might be problematic.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +331 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/llm_svc.py +472 -0
- docent/_llm_util/model_registry.py +130 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/__init__.py +2 -0
- docent/data_models/agent_run.py +17 -29
- docent/data_models/chat/__init__.py +6 -1
- docent/data_models/chat/message.py +3 -1
- docent/data_models/citation.py +103 -22
- docent/data_models/judge.py +19 -0
- docent/data_models/metadata_util.py +16 -0
- docent/data_models/remove_invalid_citation_ranges.py +23 -10
- docent/data_models/transcript.py +25 -80
- docent/data_models/util.py +170 -0
- docent/judges/__init__.py +23 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +587 -0
- docent/judges/runner.py +129 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +311 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +86 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +87 -0
- docent/judges/util/voting.py +139 -0
- docent/sdk/agent_run_writer.py +72 -21
- docent/sdk/client.py +276 -23
- docent/trace.py +413 -90
- {docent_python-0.1.14a0.dist-info → docent_python-0.1.28a0.dist-info}/METADATA +13 -5
- docent_python-0.1.28a0.dist-info/RECORD +59 -0
- docent/data_models/metadata.py +0 -229
- docent/data_models/yaml_util.py +0 -12
- docent_python-0.1.14a0.dist-info/RECORD +0 -32
- {docent_python-0.1.14a0.dist-info → docent_python-0.1.28a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.14a0.dist-info → docent_python-0.1.28a0.dist-info}/licenses/LICENSE.md +0 -0

docent/_llm_util/llm_svc.py
@@ -0,0 +1,472 @@
import time
import traceback
from functools import partial
from typing import (
    Any,
    Literal,
    Protocol,
    Sequence,
    cast,
    runtime_checkable,
)

import anyio
from anyio import Lock, Semaphore
from anyio.abc import TaskGroup
from tqdm.auto import tqdm

from docent._llm_util.data_models.exceptions import (
    DocentUsageLimitException,
    LLMException,
    RateLimitException,
    ValidationFailedException,
)
from docent._llm_util.data_models.llm_output import (
    AsyncLLMOutputStreamingCallback,
    AsyncSingleLLMOutputStreamingCallback,
    LLMOutput,
)
from docent._llm_util.llm_cache import LLMCache
from docent._llm_util.providers.preference_types import ModelOption
from docent._llm_util.providers.provider_registry import (
    PROVIDERS,
    SingleOutputGetter,
    SingleStreamingOutputGetter,
)
from docent._log_util import get_logger
from docent.data_models.chat import ChatMessage, ToolInfo, parse_chat_message

logger = get_logger(__name__)

MAX_VALIDATION_ATTEMPTS = 3
DEFAULT_MAX_CONCURRENCY = 100
DEFAULT_SVC_MAX_CONCURRENCY = 100


@runtime_checkable
class MessageResolver(Protocol):
    def __call__(self) -> list[ChatMessage | dict[str, Any]]: ...


MessagesInput = Sequence[ChatMessage | dict[str, Any]] | MessageResolver


def _resolve_messages_input(messages_input: MessagesInput) -> list[ChatMessage]:
    raw_messages = (
        messages_input() if isinstance(messages_input, MessageResolver) else messages_input
    )
    return [parse_chat_message(msg) for msg in raw_messages]


def _get_single_streaming_callback(
    batch_index: int,
    streaming_callback: AsyncLLMOutputStreamingCallback,
) -> AsyncSingleLLMOutputStreamingCallback:
    async def single_streaming_callback(llm_output: LLMOutput):
        await streaming_callback(batch_index, llm_output)

    return single_streaming_callback


async def _parallelize_calls(
    single_output_getter: SingleOutputGetter | SingleStreamingOutputGetter,
    streaming_callback: AsyncLLMOutputStreamingCallback | None,
    validation_callback: AsyncLLMOutputStreamingCallback | None,
    completion_callback: AsyncLLMOutputStreamingCallback | None,
    # Arguments for the individual completion getter
    client: Any,
    inputs: Sequence[MessagesInput],
    model_name: str,
    tools: list[ToolInfo] | None,
    tool_choice: Literal["auto", "required"] | None,
    max_new_tokens: int,
    temperature: float,
    reasoning_effort: Literal["minimal", "low", "medium", "high"] | None,
    logprobs: bool,
    top_logprobs: int | None,
    timeout: float,
    semaphore: Semaphore,
    # use_tqdm: bool,
    cache: LLMCache | None = None,
):
    base_func = partial(
        single_output_getter,
        client=client,
        model_name=model_name,
        tools=tools,
        tool_choice=tool_choice,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        reasoning_effort=reasoning_effort,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        timeout=timeout,
    )

    responses: list[LLMOutput | None] = [None for _ in inputs]
    pbar = (
        tqdm(
            total=len(inputs),
            desc=f"Calling {model_name} (reasoning_effort={reasoning_effort}) API",
        )
        if len(inputs) > 1
        else None
    )

    # Save resolved messages to avoid multiple resolutions
    resolved_messages: list[list[ChatMessage] | None] = [None] * len(inputs)

    # Not sure why the cast is necessary for the type checker
    cancelled_due_to_usage_limit: bool = cast(bool, False)

    async def _limited_task(i: int, cur_input: MessagesInput, tg: TaskGroup):
        nonlocal responses, pbar, resolved_messages, cancelled_due_to_usage_limit

        async with semaphore:
            messages = _resolve_messages_input(cur_input)
            resolved_messages[i] = messages

            retry_count = 0
            result = None
            call_started_at: float | None = None

            # Check if there's a cached result
            cached_result = (
                cache.get(
                    messages,
                    model_name,
                    tools=tools,
                    tool_choice=tool_choice,
                    reasoning_effort=reasoning_effort,
                    temperature=temperature,
                    logprobs=logprobs,
                    top_logprobs=top_logprobs,
                )
                if cache is not None
                else None
            )
            if cached_result is not None:
                result = cached_result
                if streaming_callback is not None:
                    await streaming_callback(i, result)
            else:
                call_started_at = time.perf_counter()
                while retry_count < MAX_VALIDATION_ATTEMPTS:
                    try:
                        if streaming_callback is None:
                            result = await base_func(client=client, messages=messages)
                        else:
                            result = await base_func(
                                client=client,
                                streaming_callback=_get_single_streaming_callback(
                                    i, streaming_callback
                                ),
                                messages=messages,
                            )

                        # Validate if validation callback provided and result is successful
                        if validation_callback and not result.did_error:
                            await validation_callback(i, result)

                        break
                    except ValidationFailedException as e:
                        retry_count += 1
                        logger.warning(
                            f"Validation failed for {model_name} after {retry_count} attempts: {e}"
                        )
                        if retry_count >= MAX_VALIDATION_ATTEMPTS:
                            logger.error(
                                f"Validation failed for {model_name} after {retry_count} attempts. Original output: {e.failed_output}"
                            )
                            result = LLMOutput(
                                model=model_name,
                                completions=[],
                                errors=[e],
                            )
                            break
                    except DocentUsageLimitException as _:
                        result = LLMOutput(
                            model=model_name,
                            completions=[],
                            errors=[],  # Usage limit exceptions will be added to all results later if cancelled_due_to_usage_limit
                        )
                        cancelled_due_to_usage_limit = True
                        tg.cancel_scope.cancel()
                        break
                    except Exception as e:
                        if not isinstance(e, LLMException):
                            logger.error(
                                f"LLM call raised an exception that is not an LLMException: {e}. Failure traceback:\n{traceback.format_exc()}"
                            )
                            llm_exception = LLMException(e)
                            llm_exception.__cause__ = e
                        else:
                            llm_exception = e

                        error_message = f"Call to {model_name} failed even with backoff: {e.__class__.__name__}."

                        if not isinstance(e, RateLimitException):
                            error_message += f" Failure traceback:\n{traceback.format_exc()}"
                        logger.error(error_message)

                        result = LLMOutput(
                            model=model_name,
                            completions=[],
                            errors=[llm_exception],
                        )
                        break

            # Only store the elapsed time if we didn't hit the cache and the call was successful
            if cached_result is None and result is not None and call_started_at is not None:
                result.duration = time.perf_counter() - call_started_at

            # Always call completion callback with final result (success or error)
            if completion_callback and result is not None:
                try:
                    await completion_callback(i, result)
                # LLMService uses this callback to record cost, and may throw an error if we just exceeded limit
                except DocentUsageLimitException as e:
                    result.errors.append(e)
                    cancelled_due_to_usage_limit = True
                    tg.cancel_scope.cancel()

            responses[i] = result
            if pbar is not None:
                pbar.update(1)
            if pbar is None or pbar.n == pbar.total:
                tg.cancel_scope.cancel()

    def _cache_responses():
        nonlocal responses, cache

        if cache is not None:
            indices = [
                i
                for i, response in enumerate(responses)
                if resolved_messages[i] is not None
                and response is not None
                and not response.did_error
            ]
            cache.set_batch(
                # We already checked that each index has a resolved messages list
                [cast(list[ChatMessage], resolved_messages[i]) for i in indices],
                model_name,
                # We already checked that each index corresponds to an LLMOutput object
                [cast(LLMOutput, responses[i]) for i in indices],
                tools=tools,
                tool_choice=tool_choice,
                reasoning_effort=reasoning_effort,
                temperature=temperature,
                logprobs=logprobs,
                top_logprobs=top_logprobs,
            )
            return len(indices)
        else:
            return 0

    # Get all results concurrently
    try:
        async with anyio.create_task_group() as tg:
            # Start all the individual tasks
            for i, cur_input in enumerate(inputs):
                tg.start_soon(_limited_task, i, cur_input, tg)

    # Cache what we have so far if something got cancelled
    except anyio.get_cancelled_exc_class():
        num_cached = _cache_responses()
        if num_cached:
            logger.info(
                f"Cancelled {len(inputs) - num_cached} unfinished LLM API calls, but cached {num_cached} completed responses"
            )

        # If the task was cancelled due to usage limit, set the response to a usage limit exception
        if cancelled_due_to_usage_limit:
            for i, response in enumerate(responses):
                if response is None:
                    responses[i] = LLMOutput(
                        model=model_name,
                        completions=[],
                        errors=[DocentUsageLimitException()],
                    )
                else:
                    response.errors.append(DocentUsageLimitException())

        raise

    # Cache results if available
    _cache_responses()

    # At this point, all indices should have a result
    assert all(
        isinstance(r, LLMOutput) for r in responses
    ), "Some indices were never set to an LLMOutput, which should never happen"

    return cast(list[LLMOutput], responses)


class BaseLLMService:
    def __init__(self, max_concurrency: int = DEFAULT_SVC_MAX_CONCURRENCY):
        self.max_concurrency, self._semaphore = max_concurrency, Semaphore(max_concurrency)
        self._client_cache: dict[tuple[str, str | None], Any] = {}  # (provider, api_key) -> client
        self._client_cache_lock = Lock()

    async def _get_cached_client(self, provider: str, override_key: str | None) -> Any:
        """Return a cached client for the provider/api-key tuple, creating one if needed."""
        cache_key = (provider, override_key)
        async with self._client_cache_lock:
            cached = self._client_cache.get(cache_key)
            if cached is not None:
                return cached

            client_factory = PROVIDERS[provider]["async_client_getter"]
            new_client = client_factory(override_key)
            self._client_cache[cache_key] = new_client
            return new_client

    async def get_completions(
        self,
        *,
        inputs: Sequence[MessagesInput],
        model_options: list[ModelOption],
        tools: list[ToolInfo] | None = None,
        tool_choice: Literal["auto", "required"] | None = None,
        max_new_tokens: int = 1024,
        temperature: float = 1.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        timeout: float = 120.0,
        streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
        validation_callback: AsyncLLMOutputStreamingCallback | None = None,
        completion_callback: AsyncLLMOutputStreamingCallback | None = None,
        use_cache: bool = False,
        _api_key_overrides: dict[str, str] = dict(),
    ) -> list[LLMOutput]:
        """Request completions from a configured LLM provider."""

        # We don't support logprobs for Anthropic yet
        if logprobs:
            for model_option in model_options:
                if model_option.provider == "anthropic":
                    raise ValueError(
                        f"Logprobs are not supported for Anthropic, so we can't use model {model_option.model_name}"
                    )

        # Instantiate cache
        # TODO(mengk): make this more robust, possibly move to a NoSQL database or something
        try:
            cache = LLMCache() if use_cache else None
        except ValueError as e:
            logger.warning(f"Disabling LLM cache due to init error: {e}")
            cache = None

        # Initialize pointer to which model we're using; used for model rotation after failures
        current_model_option_index = 0

        def _rotate_model_option() -> ModelOption | None:
            nonlocal current_model_option_index

            current_model_option_index += 1
            if current_model_option_index >= len(model_options):
                logger.error("All model options are exhausted")
                return None

            new_model_option = model_options[current_model_option_index]
            logger.warning(f"Switched to next model {new_model_option.model_name}")
            return new_model_option

        while True:
            # Parse the current model option
            cur_option = model_options[current_model_option_index]
            provider, model_name, reasoning_effort = (
                cur_option.provider,
                cur_option.model_name,
                cur_option.reasoning_effort,
            )

            override_key = _api_key_overrides.get(provider)

            client = await self._get_cached_client(provider, override_key)
            single_output_getter = PROVIDERS[provider]["single_output_getter"]
            single_streaming_output_getter = PROVIDERS[provider]["single_streaming_output_getter"]

            # Get completions for uncached messages
            outputs: list[LLMOutput] = await _parallelize_calls(
                (
                    single_output_getter
                    if streaming_callback is None
                    else single_streaming_output_getter
                ),
                streaming_callback,
                validation_callback,
                completion_callback,
                client,
                inputs,
                model_name,
                tools=tools,
                tool_choice=tool_choice,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                reasoning_effort=reasoning_effort,
                logprobs=logprobs,
                top_logprobs=top_logprobs,
                timeout=timeout,
                semaphore=self._semaphore,
                cache=cache,
            )
            assert len(outputs) == len(inputs), "Number of outputs must match number of messages"

            # Only count errors that should trigger model rotation (API errors, not validation/usage errors)
            num_rotation_errors = sum(
                1
                for output in outputs
                if output.did_error
                and any(
                    not isinstance(e, (ValidationFailedException, DocentUsageLimitException))
                    for e in output.errors
                )
            )
            if num_rotation_errors > 0:
                logger.warning(f"{model_name}: {num_rotation_errors} API errors")
                if not _rotate_model_option():
                    break
            else:
                break

        return outputs


async def get_llm_completions_async(
    inputs: list[MessagesInput],
    model_options: list[ModelOption],
    tools: list[ToolInfo] | None = None,
    tool_choice: Literal["auto", "required"] | None = None,
    max_new_tokens: int = 1024,
    temperature: float = 1.0,
    logprobs: bool = False,
    top_logprobs: int | None = None,
    timeout: float = 120.0,
    streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
    validation_callback: AsyncLLMOutputStreamingCallback | None = None,
    completion_callback: AsyncLLMOutputStreamingCallback | None = None,
    use_cache: bool = False,
    _api_key_overrides: dict[str, str] = dict(),
) -> list[LLMOutput]:
    """Convenience method for backward compatibility"""

    svc = BaseLLMService()
    return await svc.get_completions(
        inputs=inputs,
        model_options=model_options,
        tools=tools,
        tool_choice=tool_choice,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        timeout=timeout,
        streaming_callback=streaming_callback,
        validation_callback=validation_callback,
        completion_callback=completion_callback,
        use_cache=use_cache,
        _api_key_overrides=_api_key_overrides,
    )
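
The service above fans requests out under a shared semaphore, retries on validation failures, caches successful outputs, and rotates to the next ModelOption after provider API errors. The sketch below is a minimal, hypothetical usage example and is not part of the diff: the ModelOption constructor keywords, the provider keys "openai"/"anthropic", and the message-dict shape are assumptions, since preference_types.py and the chat parser are not shown in this hunk; get_llm_completions_async, MessagesInput, and LLMOutput come from the code above.

import anyio

from docent._llm_util.llm_svc import get_llm_completions_async
from docent._llm_util.providers.preference_types import ModelOption


def lazy_messages() -> list[dict]:
    # Acts as a MessageResolver: any zero-argument callable returning messages.
    # Resolution is deferred until the task acquires a semaphore slot.
    # The dict shape here is an assumption; parse_chat_message defines what is accepted.
    return [{"role": "user", "content": "Summarize this transcript in one sentence."}]


async def main() -> None:
    outputs = await get_llm_completions_async(
        # Each element is one request: a message list or a resolver callable.
        inputs=[lazy_messages],
        # Tried in order; the service rotates to the next option after API errors.
        # ModelOption kwargs are assumed; see preference_types.py for the real signature.
        model_options=[
            ModelOption(provider="openai", model_name="gpt-5-mini", reasoning_effort="low"),
            ModelOption(provider="anthropic", model_name="claude-sonnet-4", reasoning_effort=None),
        ],
        max_new_tokens=256,
        use_cache=True,
    )
    for output in outputs:
        if output.did_error:
            print("errors:", output.errors)
        else:
            print(output.completions)


anyio.run(main)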

docent/_llm_util/model_registry.py
@@ -0,0 +1,130 @@
from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Optional

from docent._llm_util.data_models.llm_output import TokenType
from docent._log_util import get_logger

logger = get_logger(__name__)


"""
Values are USD per million tokens
"""
ModelRate = dict[TokenType, float]


@dataclass(frozen=True)
class ModelInfo:
    """
    Information about a model, including its rate and context window. Not to be confused with ModelOption.
    """

    # Values are per 1,000,000 tokens
    rate: Optional[ModelRate]
    # Total context window tokens
    context_window: int


# Note: some providers charge extra for long prompts/outputs. We don't account for this yet.
_REGISTRY: list[tuple[str, ModelInfo]] = [
    (
        "gpt-5-nano",
        ModelInfo(rate={"input": 0.05, "output": 0.40}, context_window=400_000),
    ),
    (
        "gpt-5-mini",
        ModelInfo(rate={"input": 0.25, "output": 2.0}, context_window=400_000),
    ),
    (
        "gpt-5",
        ModelInfo(rate={"input": 1.25, "output": 10.0}, context_window=400_000),
    ),
    (
        "gpt-4o",
        ModelInfo(rate={"input": 2.50, "output": 10.00}, context_window=100_000),
    ),
    (
        "o4-mini",
        ModelInfo(rate={"input": 1.10, "output": 4.40}, context_window=100_000),
    ),
    (
        "claude-sonnet-4",
        ModelInfo(rate={"input": 3.0, "output": 15.0}, context_window=200_000),
    ),
    (
        "claude-haiku-4-5",
        ModelInfo(rate={"input": 1.0, "output": 5.0}, context_window=200_000),
    ),
    (
        "gemini-2.5-flash-lite",
        ModelInfo(
            rate={"input": 0.10, "output": 0.40},
            context_window=1_000_000,
        ),
    ),
    (
        "gemini-2.5-flash",
        ModelInfo(
            rate={"input": 0.30, "output": 2.50},
            context_window=1_000_000,
        ),
    ),
    (
        "gemini-2.5-pro",
        ModelInfo(
            rate={"input": 1.25, "output": 10.00},
            context_window=1_000_000,
        ),
    ),
    (
        "grok-4-fast",
        ModelInfo(
            rate={"input": 0.20, "output": 0.50},
            context_window=2_000_000,
        ),
    ),
    (
        "grok-4",
        ModelInfo(
            rate={"input": 3.0, "output": 15.0},
            context_window=256_000,
        ),
    ),
]


@lru_cache(maxsize=None)
def get_model_info(model_name: str) -> Optional[ModelInfo]:
    for registry_model_name, info in _REGISTRY:
        if registry_model_name in model_name:
            return info
    return None


def get_context_window(model_name: str) -> int:
    info = get_model_info(model_name)
    if info is None:
        logger.warning(f"No context window found for model {model_name}")
        return 100_000
    return info.context_window


def get_rates_for_model_name(model_name: str) -> Optional[ModelRate]:
    info = get_model_info(model_name)
    return info.rate if info is not None else None


def estimate_cost_cents(model_name: str, token_count: int, token_type: TokenType) -> float:
    rate = get_rates_for_model_name(model_name)
    if rate is None:
        logger.warning(f"No rate found for model {model_name}")
        return 0.0
    usd_per_mtok = rate.get(token_type)
    if usd_per_mtok is None:
        logger.warning(f"No rate found for model {model_name} token type {token_type}")
        return 0.0
    cents_per_token = usd_per_mtok * 100 / 1_000_000.0
    return token_count * cents_per_token
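
The registry above drives cost and context-window lookups by substring match on the model name. A small worked example follows (not part of the diff); the dated model-name suffix is hypothetical, but it still resolves to the "gpt-5-mini" entry:

from docent._llm_util.model_registry import estimate_cost_cents, get_context_window

# Hypothetical dated variant; "gpt-5-mini" is a substring, so get_model_info matches it.
model = "gpt-5-mini-2025-08-07"

print(get_context_window(model))  # 400000

# gpt-5-mini output rate is $2.00 per million tokens:
#   cents_per_token = 2.00 * 100 / 1_000_000 = 0.0002
#   2_000_000 tokens * 0.0002 cents/token = 400.0 cents ($4.00)
print(estimate_cost_cents(model, 2_000_000, "output"))  # 400.0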