docent-python 0.1.20a0__py3-none-any.whl → 0.1.22a0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in a supported public registry; it is provided for informational purposes only.
Potentially problematic release: this version of docent-python may be problematic.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +320 -0
- docent/_llm_util/data_models/simple_svc.py +79 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/model_registry.py +126 -0
- docent/_llm_util/prod_llms.py +454 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/transcript.py +2 -0
- docent/judges/__init__.py +21 -0
- docent/judges/impl.py +232 -0
- docent/judges/types.py +240 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +84 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +95 -0
- docent/judges/util/voting.py +114 -0
- docent/trace.py +19 -1
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/METADATA +7 -1
- docent_python-0.1.22a0.dist-info/RECORD +58 -0
- docent_python-0.1.20a0.dist-info/RECORD +0 -34
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.20a0.dist-info → docent_python-0.1.22a0.dist-info}/licenses/LICENSE.md +0 -0
docent/_llm_util/model_registry.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Optional
+
+from docent._llm_util.data_models.llm_output import TokenType
+from docent._log_util import get_logger
+
+logger = get_logger(__name__)
+
+
+"""
+Values are USD per million tokens
+"""
+ModelRate = dict[TokenType, float]
+
+
+@dataclass(frozen=True)
+class ModelInfo:
+    """
+    Information about a model, including its rate and context window. Not to be confused with ModelOption.
+    """
+
+    # Values are per 1,000,000 tokens
+    rate: Optional[ModelRate]
+    # Total context window tokens
+    context_window: int
+
+
+# Note: some providers charge extra for long prompts/outputs. We don't account for this yet.
+_REGISTRY: list[tuple[str, ModelInfo]] = [
+    (
+        "gpt-5-nano",
+        ModelInfo(rate={"input": 0.05, "output": 0.40}, context_window=400_000),
+    ),
+    (
+        "gpt-5-mini",
+        ModelInfo(rate={"input": 0.25, "output": 2.0}, context_window=400_000),
+    ),
+    (
+        "gpt-5",
+        ModelInfo(rate={"input": 1.25, "output": 10.0}, context_window=400_000),
+    ),
+    (
+        "gpt-4o",
+        ModelInfo(rate={"input": 2.50, "output": 10.00}, context_window=100_000),
+    ),
+    (
+        "o4-mini",
+        ModelInfo(rate={"input": 1.10, "output": 4.40}, context_window=100_000),
+    ),
+    (
+        "claude-sonnet-4",
+        ModelInfo(rate={"input": 3.0, "output": 15.0}, context_window=200_000),
+    ),
+    (
+        "gemini-2.5-flash-lite",
+        ModelInfo(
+            rate={"input": 0.10, "output": 0.40},
+            context_window=1_000_000,
+        ),
+    ),
+    (
+        "gemini-2.5-flash",
+        ModelInfo(
+            rate={"input": 0.30, "output": 2.50},
+            context_window=1_000_000,
+        ),
+    ),
+    (
+        "gemini-2.5-pro",
+        ModelInfo(
+            rate={"input": 1.25, "output": 10.00},
+            context_window=1_000_000,
+        ),
+    ),
+    (
+        "grok-4-fast",
+        ModelInfo(
+            rate={"input": 0.20, "output": 0.50},
+            context_window=2_000_000,
+        ),
+    ),
+    (
+        "grok-4",
+        ModelInfo(
+            rate={"input": 3.0, "output": 15.0},
+            context_window=256_000,
+        ),
+    ),
+]
+
+
+@lru_cache(maxsize=None)
+def get_model_info(model_name: str) -> Optional[ModelInfo]:
+    for registry_model_name, info in _REGISTRY:
+        if registry_model_name in model_name:
+            return info
+    return None
+
+
+def get_context_window(model_name: str) -> int:
+    info = get_model_info(model_name)
+    if info is None:
+        logger.warning(f"No context window found for model {model_name}")
+        return 100_000
+    return info.context_window
+
+
+def get_rates_for_model_name(model_name: str) -> Optional[ModelRate]:
+    info = get_model_info(model_name)
+    return info.rate if info is not None else None
+
+
+def estimate_cost_cents(model_name: str, token_count: int, token_type: TokenType) -> float:
+    rate = get_rates_for_model_name(model_name)
+    if rate is None:
+        logger.warning(f"No rate found for model {model_name}")
+        return 0.0
+    usd_per_mtok = rate.get(token_type)
+    if usd_per_mtok is None:
+        logger.warning(f"No rate found for model {model_name} token type {token_type}")
+        return 0.0
+    cents_per_token = usd_per_mtok * 100 / 1_000_000.0
+    return token_count * cents_per_token
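
For orientation, here is a minimal usage sketch of the helpers added in this hunk. It is not part of the diff: the dated model name and the token count are hypothetical, and the expected outputs simply follow the substring lookup and per-million-token arithmetic shown above.

# Illustrative only; not part of the diff. The dated model name is hypothetical;
# get_model_info matches registry entries by substring, so it resolves to "gpt-5-mini".
from docent._llm_util.model_registry import (
    estimate_cost_cents,
    get_context_window,
    get_rates_for_model_name,
)

model = "gpt-5-mini-2025-08-07"

print(get_context_window(model))        # 400000
print(get_rates_for_model_name(model))  # {'input': 0.25, 'output': 2.0}

# 10,000 input tokens at $0.25 per million tokens -> 0.25 cents
print(estimate_cost_cents(model, token_count=10_000, token_type="input"))
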
docent/_llm_util/prod_llms.py
@@ -0,0 +1,454 @@
+"""
+At some point we'll want to do a refactor to support different types of provider/key swapping
+due to different scenarios. However, this'll probably be a breaking change, which is why I'm
+not doing it now.
+
+- mengk
+"""
+
+import traceback
+from contextlib import nullcontext
+from functools import partial
+from typing import (
+    Any,
+    AsyncContextManager,
+    Literal,
+    Protocol,
+    Sequence,
+    cast,
+    runtime_checkable,
+)
+
+import anyio
+from anyio.abc import TaskGroup
+from tqdm.auto import tqdm
+
+from docent._llm_util.data_models.exceptions import (
+    DocentUsageLimitException,
+    LLMException,
+    RateLimitException,
+    ValidationFailedException,
+)
+from docent._llm_util.data_models.llm_output import (
+    AsyncLLMOutputStreamingCallback,
+    AsyncSingleLLMOutputStreamingCallback,
+    LLMOutput,
+)
+from docent._llm_util.llm_cache import LLMCache
+from docent._llm_util.providers.preference_types import ModelOption
+from docent._llm_util.providers.provider_registry import (
+    PROVIDERS,
+    SingleOutputGetter,
+    SingleStreamingOutputGetter,
+)
+from docent._log_util import get_logger
+from docent.data_models.chat import ChatMessage, ToolInfo, parse_chat_message
+
+MAX_VALIDATION_ATTEMPTS = 3
+
+logger = get_logger(__name__)
+
+
+@runtime_checkable
+class MessageResolver(Protocol):
+    def __call__(self) -> list[ChatMessage | dict[str, Any]]: ...
+
+
+MessagesInput = Sequence[ChatMessage | dict[str, Any]] | MessageResolver
+
+
+def _resolve_messages_input(messages_input: MessagesInput) -> list[ChatMessage]:
+    raw_messages = (
+        messages_input() if isinstance(messages_input, MessageResolver) else messages_input
+    )
+    return [parse_chat_message(msg) for msg in raw_messages]
+
+
+def _get_single_streaming_callback(
+    batch_index: int,
+    streaming_callback: AsyncLLMOutputStreamingCallback,
+) -> AsyncSingleLLMOutputStreamingCallback:
+    async def single_streaming_callback(llm_output: LLMOutput):
+        await streaming_callback(batch_index, llm_output)
+
+    return single_streaming_callback
+
+
+async def _parallelize_calls(
+    single_output_getter: SingleOutputGetter | SingleStreamingOutputGetter,
+    streaming_callback: AsyncLLMOutputStreamingCallback | None,
+    validation_callback: AsyncLLMOutputStreamingCallback | None,
+    completion_callback: AsyncLLMOutputStreamingCallback | None,
+    # Arguments for the individual completion getter
+    client: Any,
+    inputs: list[MessagesInput],
+    model_name: str,
+    tools: list[ToolInfo] | None,
+    tool_choice: Literal["auto", "required"] | None,
+    max_new_tokens: int,
+    temperature: float,
+    reasoning_effort: Literal["low", "medium", "high"] | None,
+    logprobs: bool,
+    top_logprobs: int | None,
+    timeout: float,
+    semaphore: AsyncContextManager[anyio.Semaphore] | None,
+    # use_tqdm: bool,
+    cache: LLMCache | None = None,
+):
+    base_func = partial(
+        single_output_getter,
+        client=client,
+        model_name=model_name,
+        tools=tools,
+        tool_choice=tool_choice,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        reasoning_effort=reasoning_effort,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
+        timeout=timeout,
+    )
+
+    responses: list[LLMOutput | None] = [None for _ in inputs]
+    pbar = (
+        tqdm(
+            total=len(inputs),
+            desc=f"Calling {model_name} (reasoning_effort={reasoning_effort}) API",
+        )
+        if len(inputs) > 1
+        else None
+    )
+
+    # Save resolved messages to avoid multiple resolutions
+    resolved_messages: list[list[ChatMessage] | None] = [None] * len(inputs)
+
+    cancelled_due_to_usage_limit: bool = False
+
+    async def _limited_task(i: int, cur_input: MessagesInput, tg: TaskGroup):
+        nonlocal responses, pbar, resolved_messages, cancelled_due_to_usage_limit
+
+        async with semaphore or nullcontext():
+            messages = _resolve_messages_input(cur_input)
+            resolved_messages[i] = messages
+
+            retry_count = 0
+            result = None
+
+            # Check if there's a cached result
+            cached_result = (
+                cache.get(
+                    messages,
+                    model_name,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    reasoning_effort=reasoning_effort,
+                    temperature=temperature,
+                    logprobs=logprobs,
+                    top_logprobs=top_logprobs,
+                )
+                if cache is not None
+                else None
+            )
+            if cached_result is not None:
+                result = cached_result
+                if streaming_callback is not None:
+                    await streaming_callback(i, result)
+            else:
+                while retry_count < MAX_VALIDATION_ATTEMPTS:
+                    try:
+                        if streaming_callback is None:
+                            result = await base_func(client=client, messages=messages)
+                        else:
+                            result = await base_func(
+                                client=client,
+                                streaming_callback=_get_single_streaming_callback(
+                                    i, streaming_callback
+                                ),
+                                messages=messages,
+                            )
+
+                        # Validate if validation callback provided and result is successful
+                        if validation_callback and not result.did_error:
+                            await validation_callback(i, result)
+
+                        break
+                    except ValidationFailedException as e:
+                        retry_count += 1
+                        logger.warning(
+                            f"Validation failed for {model_name} after {retry_count} attempts: {e}"
+                        )
+                        if retry_count >= MAX_VALIDATION_ATTEMPTS:
+                            logger.error(
+                                f"Validation failed for {model_name} after {MAX_VALIDATION_ATTEMPTS} attempts: {e}"
+                            )
+                            result = LLMOutput(
+                                model=model_name,
+                                completions=[],
+                                errors=[e],
+                            )
+                            break
+                    except DocentUsageLimitException as e:
+                        result = LLMOutput(
+                            model=model_name,
+                            completions=[],
+                            errors=[],  # Usage limit exceptions will be added to all results later if cancelled_due_to_usage_limit
+                        )
+                        cancelled_due_to_usage_limit = True
+                        tg.cancel_scope.cancel()
+                        break
+                    except Exception as e:
+                        if not isinstance(e, LLMException):
+                            logger.warning(
+                                f"LLM call raised an exception that is not an LLMException: {e}"
+                            )
+                            llm_exception = LLMException(e)
+                            llm_exception.__cause__ = e
+                        else:
+                            llm_exception = e
+
+                        error_message = f"Call to {model_name} failed even with backoff: {e.__class__.__name__}."
+
+                        if not isinstance(e, RateLimitException):
+                            error_message += f" Failure traceback:\n{traceback.format_exc()}"
+                        logger.error(error_message)
+
+                        result = LLMOutput(
+                            model=model_name,
+                            completions=[],
+                            errors=[llm_exception],
+                        )
+                        break
+
+            # Always call completion callback with final result (success or error)
+            if completion_callback and result is not None:
+                try:
+                    await completion_callback(i, result)
+                # LLMService uses this callback to record cost, and may throw an error if we just exceeded limit
+                except DocentUsageLimitException as e:
+                    result.errors.append(e)
+                    cancelled_due_to_usage_limit = True
+                    tg.cancel_scope.cancel()
+
+            responses[i] = result
+            if pbar is not None:
+                pbar.update(1)
+            if pbar is None or pbar.n == pbar.total:
+                tg.cancel_scope.cancel()
+
+    def _cache_responses():
+        nonlocal responses, cache
+
+        if cache is not None:
+            indices = [
+                i
+                for i, response in enumerate(responses)
+                if resolved_messages[i] is not None
+                and response is not None
+                and not response.did_error
+            ]
+            cache.set_batch(
+                # We already checked that each index has a resolved messages list
+                [cast(list[ChatMessage], resolved_messages[i]) for i in indices],
+                model_name,
+                # We already checked that each index corresponds to an LLMOutput object
+                [cast(LLMOutput, responses[i]) for i in indices],
+                tools=tools,
+                tool_choice=tool_choice,
+                reasoning_effort=reasoning_effort,
+                temperature=temperature,
+                logprobs=logprobs,
+                top_logprobs=top_logprobs,
+            )
+            return len(indices)
+        else:
+            return 0
+
+    # Get all results concurrently
+    try:
+        async with anyio.create_task_group() as tg:
+            # Start all the individual tasks
+            for i, cur_input in enumerate(inputs):
+                tg.start_soon(_limited_task, i, cur_input, tg)
+
+    # Cache what we have so far if something got cancelled
+    except anyio.get_cancelled_exc_class():
+        num_cached = _cache_responses()
+        logger.info(
+            f"Cancelled {len(inputs) - num_cached} unfinished LLM API calls; cached {num_cached} completed responses"
+        )
+        raise
+
+    if cancelled_due_to_usage_limit:
+        for i in range(len(responses)):
+            if responses[i] is None:
+                responses[i] = LLMOutput(
+                    model=model_name,
+                    completions=[],
+                    errors=[DocentUsageLimitException()],
+                )
+            else:
+                responses[i].errors.append(DocentUsageLimitException())
+
+    # Cache results if available
+    _cache_responses()
+
+    # At this point, all indices should have a result
+    assert all(
+        isinstance(r, LLMOutput) for r in responses
+    ), "Some indices were never set to an LLMOutput, which should never happen"
+
+    return cast(list[LLMOutput], responses)
+
+
+class LLMManager:
+    def __init__(
+        self,
+        model_options: list[ModelOption],
+        api_key_overrides: dict[str, str] | None = None,
+        use_cache: bool = False,
+    ):
+        # TODO(mengk): make this more robust, possibly move to a NoSQL database or something
+        try:
+            self.cache = LLMCache() if use_cache else None
+        except ValueError as e:
+            logger.warning(f"Disabling LLM cache due to init error: {e}")
+            self.cache = None
+
+        self.model_options = model_options
+        self.current_model_option_index = 0
+        self.api_key_overrides = api_key_overrides or {}
+
+    async def get_completions(
+        self,
+        inputs: list[MessagesInput],
+        tools: list[ToolInfo] | None = None,
+        tool_choice: Literal["auto", "required"] | None = None,
+        max_new_tokens: int = 32,
+        temperature: float = 1.0,
+        logprobs: bool = False,
+        top_logprobs: int | None = None,
+        max_concurrency: int | None = None,
+        timeout: float = 5.0,
+        streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
+        validation_callback: AsyncLLMOutputStreamingCallback | None = None,
+        completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+    ) -> list[LLMOutput]:
+        while True:
+            # Parse the current model option
+            cur_option = self.model_options[self.current_model_option_index]
+            provider, model_name, reasoning_effort = (
+                cur_option.provider,
+                cur_option.model_name,
+                cur_option.reasoning_effort,
+            )
+
+            override_key = self.api_key_overrides.get(provider)
+
+            client = PROVIDERS[provider]["async_client_getter"](override_key)
+            single_output_getter = PROVIDERS[provider]["single_output_getter"]
+            single_streaming_output_getter = PROVIDERS[provider]["single_streaming_output_getter"]
+
+            # Get completions for uncached messages
+            outputs: list[LLMOutput] = await _parallelize_calls(
+                (
+                    single_output_getter
+                    if streaming_callback is None
+                    else single_streaming_output_getter
+                ),
+                streaming_callback,
+                validation_callback,
+                completion_callback,
+                client,
+                inputs,
+                model_name,
+                tools=tools,
+                tool_choice=tool_choice,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                reasoning_effort=reasoning_effort,
+                logprobs=logprobs,
+                top_logprobs=top_logprobs,
+                timeout=timeout,
+                semaphore=(
+                    anyio.Semaphore(max_concurrency) if max_concurrency is not None else None
+                ),
+                cache=self.cache,
+            )
+            assert len(outputs) == len(inputs), "Number of outputs must match number of messages"
+
+            # Only count errors that should trigger model rotation (API errors, not validation/usage errors)
+            num_rotation_errors = sum(
+                1
+                for output in outputs
+                if output.did_error
+                and any(
+                    not isinstance(e, (ValidationFailedException, DocentUsageLimitException))
+                    for e in output.errors
+                )
+            )
+            if num_rotation_errors > 0:
+                logger.warning(f"{model_name}: {num_rotation_errors} API errors")
+                if not self._rotate_model_option():
+                    break
+            else:
+                break
+
+        return outputs
+
+    def _rotate_model_option(self) -> ModelOption | None:
+        self.current_model_option_index += 1
+        if self.current_model_option_index >= len(self.model_options):
+            logger.error("All model options are exhausted")
+            return None
+
+        new_model_option = self.model_options[self.current_model_option_index]
+        logger.warning(f"Switched to next model {new_model_option.model_name}")
+        return new_model_option


+async def get_llm_completions_async(
+    inputs: list[MessagesInput],
+    model_options: list[ModelOption],
+    tools: list[ToolInfo] | None = None,
+    tool_choice: Literal["auto", "required"] | None = None,
+    max_new_tokens: int = 1024,
+    temperature: float = 1.0,
+    logprobs: bool = False,
+    top_logprobs: int | None = None,
+    max_concurrency: int = 100,
+    timeout: float = 120.0,
+    streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
+    validation_callback: AsyncLLMOutputStreamingCallback | None = None,
+    completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+    use_cache: bool = False,
+    api_key_overrides: dict[str, str] | None = None,
+) -> list[LLMOutput]:
+    # We don't support logprobs for Anthropic yet
+    if logprobs:
+        for model_option in model_options:
+            if model_option.provider == "anthropic":
+                raise ValueError(
+                    f"Logprobs are not supported for Anthropic, so we can't use model {model_option.model_name}"
+                )
+
+    # Create the LLM manager
+    llm_manager = LLMManager(
+        model_options=model_options,
+        api_key_overrides=api_key_overrides,
+        use_cache=use_cache,
+    )
+
+    return await llm_manager.get_completions(
+        inputs,
+        tools=tools,
+        tool_choice=tool_choice,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
+        max_concurrency=max_concurrency,
+        timeout=timeout,
+        streaming_callback=streaming_callback,
+        validation_callback=validation_callback,
+        completion_callback=completion_callback,
+    )
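
A rough sketch of how the new get_llm_completions_async entry point might be invoked follows. It is not part of the diff: the ModelOption keyword arguments and the "openai" provider string are assumptions inferred from how cur_option and PROVIDERS are used above (only "anthropic" is named explicitly in this hunk), and the plain role/content message dict is assumed to be a shape parse_chat_message accepts.

# Sketch under the stated assumptions; not part of the diff.
import anyio

from docent._llm_util.prod_llms import get_llm_completions_async
from docent._llm_util.providers.preference_types import ModelOption


async def main() -> None:
    outputs = await get_llm_completions_async(
        # One batch entry; a MessagesInput may be a sequence of chat-message dicts.
        inputs=[[{"role": "user", "content": "Summarize this transcript."}]],
        # Assumed ModelOption fields, mirroring cur_option.provider / .model_name / .reasoning_effort.
        model_options=[
            ModelOption(provider="openai", model_name="gpt-5-mini", reasoning_effort="low"),
            ModelOption(provider="anthropic", model_name="claude-sonnet-4", reasoning_effort=None),
        ],
        max_new_tokens=256,
        use_cache=True,
    )
    for output in outputs:
        print(output.did_error, output.errors, len(output.completions))


anyio.run(main)

The second ModelOption only comes into play as a fallback: LLMManager._rotate_model_option switches to it when the first option keeps returning API errors.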