docent-python 0.1.20a0__py3-none-any.whl → 0.1.22a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,126 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from typing import Optional
+
+ from docent._llm_util.data_models.llm_output import TokenType
+ from docent._log_util import get_logger
+
+ logger = get_logger(__name__)
+
+
+ """
+ Values are USD per million tokens
+ """
+ ModelRate = dict[TokenType, float]
+
+
+ @dataclass(frozen=True)
+ class ModelInfo:
+     """
+     Information about a model, including its rate and context window. Not to be confused with ModelOption.
+     """
+
+     # Values are per 1,000,000 tokens
+     rate: Optional[ModelRate]
+     # Total context window tokens
+     context_window: int
+
+
+ # Note: some providers charge extra for long prompts/outputs. We don't account for this yet.
+ _REGISTRY: list[tuple[str, ModelInfo]] = [
+     (
+         "gpt-5-nano",
+         ModelInfo(rate={"input": 0.05, "output": 0.40}, context_window=400_000),
+     ),
+     (
+         "gpt-5-mini",
+         ModelInfo(rate={"input": 0.25, "output": 2.0}, context_window=400_000),
+     ),
+     (
+         "gpt-5",
+         ModelInfo(rate={"input": 1.25, "output": 10.0}, context_window=400_000),
+     ),
+     (
+         "gpt-4o",
+         ModelInfo(rate={"input": 2.50, "output": 10.00}, context_window=100_000),
+     ),
+     (
+         "o4-mini",
+         ModelInfo(rate={"input": 1.10, "output": 4.40}, context_window=100_000),
+     ),
+     (
+         "claude-sonnet-4",
+         ModelInfo(rate={"input": 3.0, "output": 15.0}, context_window=200_000),
+     ),
+     (
+         "gemini-2.5-flash-lite",
+         ModelInfo(
+             rate={"input": 0.10, "output": 0.40},
+             context_window=1_000_000,
+         ),
+     ),
+     (
+         "gemini-2.5-flash",
+         ModelInfo(
+             rate={"input": 0.30, "output": 2.50},
+             context_window=1_000_000,
+         ),
+     ),
+     (
+         "gemini-2.5-pro",
+         ModelInfo(
+             rate={"input": 1.25, "output": 10.00},
+             context_window=1_000_000,
+         ),
+     ),
+     (
+         "grok-4-fast",
+         ModelInfo(
+             rate={"input": 0.20, "output": 0.50},
+             context_window=2_000_000,
+         ),
+     ),
+     (
+         "grok-4",
+         ModelInfo(
+             rate={"input": 3.0, "output": 15.0},
+             context_window=256_000,
+         ),
+     ),
+ ]
+
+
+ @lru_cache(maxsize=None)
+ def get_model_info(model_name: str) -> Optional[ModelInfo]:
+     for registry_model_name, info in _REGISTRY:
+         if registry_model_name in model_name:
+             return info
+     return None
+
+
+ def get_context_window(model_name: str) -> int:
+     info = get_model_info(model_name)
+     if info is None:
+         logger.warning(f"No context window found for model {model_name}")
+         return 100_000
+     return info.context_window
+
+
+ def get_rates_for_model_name(model_name: str) -> Optional[ModelRate]:
+     info = get_model_info(model_name)
+     return info.rate if info is not None else None
+
+
+ def estimate_cost_cents(model_name: str, token_count: int, token_type: TokenType) -> float:
+     rate = get_rates_for_model_name(model_name)
+     if rate is None:
+         logger.warning(f"No rate found for model {model_name}")
+         return 0.0
+     usd_per_mtok = rate.get(token_type)
+     if usd_per_mtok is None:
+         logger.warning(f"No rate found for model {model_name} token type {token_type}")
+         return 0.0
+     cents_per_token = usd_per_mtok * 100 / 1_000_000.0
+     return token_count * cents_per_token
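
The first new file is a static pricing and context-window registry. A minimal usage sketch of its helpers follows; the module path in the import is an assumption (the viewer does not show the new file's name), while the helper names, the substring-based lookup, and the fallback behavior come from the code above.

# Hypothetical usage sketch for the pricing helpers above.
# Assumption: the import path below is a guess; the diff does not show the file name.
from docent._llm_util.model_registry import estimate_cost_cents, get_context_window

# Lookup is substring-based, so a dated model id still matches its registry entry.
print(get_context_window("claude-sonnet-4-20250514"))  # 200_000
print(get_context_window("some-unknown-model"))        # falls back to 100_000 and logs a warning

# gpt-5 input tokens are priced at $1.25 per million, so 200_000 tokens cost 25.0 cents.
print(estimate_cost_cents("gpt-5", 200_000, "input"))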
@@ -0,0 +1,454 @@
+ """
+ At some point we'll want to do a refactor to support different types of provider/key swapping
+ due to different scenarios. However, this'll probably be a breaking change, which is why I'm
+ not doing it now.
+
+ - mengk
+ """
+
+ import traceback
+ from contextlib import nullcontext
+ from functools import partial
+ from typing import (
+     Any,
+     AsyncContextManager,
+     Literal,
+     Protocol,
+     Sequence,
+     cast,
+     runtime_checkable,
+ )
+
+ import anyio
+ from anyio.abc import TaskGroup
+ from tqdm.auto import tqdm
+
+ from docent._llm_util.data_models.exceptions import (
+     DocentUsageLimitException,
+     LLMException,
+     RateLimitException,
+     ValidationFailedException,
+ )
+ from docent._llm_util.data_models.llm_output import (
+     AsyncLLMOutputStreamingCallback,
+     AsyncSingleLLMOutputStreamingCallback,
+     LLMOutput,
+ )
+ from docent._llm_util.llm_cache import LLMCache
+ from docent._llm_util.providers.preference_types import ModelOption
+ from docent._llm_util.providers.provider_registry import (
+     PROVIDERS,
+     SingleOutputGetter,
+     SingleStreamingOutputGetter,
+ )
+ from docent._log_util import get_logger
+ from docent.data_models.chat import ChatMessage, ToolInfo, parse_chat_message
+
+ MAX_VALIDATION_ATTEMPTS = 3
+
+ logger = get_logger(__name__)
+
+
+ @runtime_checkable
+ class MessageResolver(Protocol):
+     def __call__(self) -> list[ChatMessage | dict[str, Any]]: ...
+
+
+ MessagesInput = Sequence[ChatMessage | dict[str, Any]] | MessageResolver
+
+
+ def _resolve_messages_input(messages_input: MessagesInput) -> list[ChatMessage]:
+     raw_messages = (
+         messages_input() if isinstance(messages_input, MessageResolver) else messages_input
+     )
+     return [parse_chat_message(msg) for msg in raw_messages]
+
+
+ def _get_single_streaming_callback(
+     batch_index: int,
+     streaming_callback: AsyncLLMOutputStreamingCallback,
+ ) -> AsyncSingleLLMOutputStreamingCallback:
+     async def single_streaming_callback(llm_output: LLMOutput):
+         await streaming_callback(batch_index, llm_output)
+
+     return single_streaming_callback
+
+
+ async def _parallelize_calls(
+     single_output_getter: SingleOutputGetter | SingleStreamingOutputGetter,
+     streaming_callback: AsyncLLMOutputStreamingCallback | None,
+     validation_callback: AsyncLLMOutputStreamingCallback | None,
+     completion_callback: AsyncLLMOutputStreamingCallback | None,
+     # Arguments for the individual completion getter
+     client: Any,
+     inputs: list[MessagesInput],
+     model_name: str,
+     tools: list[ToolInfo] | None,
+     tool_choice: Literal["auto", "required"] | None,
+     max_new_tokens: int,
+     temperature: float,
+     reasoning_effort: Literal["low", "medium", "high"] | None,
+     logprobs: bool,
+     top_logprobs: int | None,
+     timeout: float,
+     semaphore: AsyncContextManager[anyio.Semaphore] | None,
+     # use_tqdm: bool,
+     cache: LLMCache | None = None,
+ ):
+     base_func = partial(
+         single_output_getter,
+         client=client,
+         model_name=model_name,
+         tools=tools,
+         tool_choice=tool_choice,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         reasoning_effort=reasoning_effort,
+         logprobs=logprobs,
+         top_logprobs=top_logprobs,
+         timeout=timeout,
+     )
+
+     responses: list[LLMOutput | None] = [None for _ in inputs]
+     pbar = (
+         tqdm(
+             total=len(inputs),
+             desc=f"Calling {model_name} (reasoning_effort={reasoning_effort}) API",
+         )
+         if len(inputs) > 1
+         else None
+     )
+
+     # Save resolved messages to avoid multiple resolutions
+     resolved_messages: list[list[ChatMessage] | None] = [None] * len(inputs)
+
+     cancelled_due_to_usage_limit: bool = False
+
+     async def _limited_task(i: int, cur_input: MessagesInput, tg: TaskGroup):
+         nonlocal responses, pbar, resolved_messages, cancelled_due_to_usage_limit
+
+         async with semaphore or nullcontext():
+             messages = _resolve_messages_input(cur_input)
+             resolved_messages[i] = messages
+
+             retry_count = 0
+             result = None
+
+             # Check if there's a cached result
+             cached_result = (
+                 cache.get(
+                     messages,
+                     model_name,
+                     tools=tools,
+                     tool_choice=tool_choice,
+                     reasoning_effort=reasoning_effort,
+                     temperature=temperature,
+                     logprobs=logprobs,
+                     top_logprobs=top_logprobs,
+                 )
+                 if cache is not None
+                 else None
+             )
+             if cached_result is not None:
+                 result = cached_result
+                 if streaming_callback is not None:
+                     await streaming_callback(i, result)
+             else:
+                 while retry_count < MAX_VALIDATION_ATTEMPTS:
+                     try:
+                         if streaming_callback is None:
+                             result = await base_func(client=client, messages=messages)
+                         else:
+                             result = await base_func(
+                                 client=client,
+                                 streaming_callback=_get_single_streaming_callback(
+                                     i, streaming_callback
+                                 ),
+                                 messages=messages,
+                             )
+
+                         # Validate if validation callback provided and result is successful
+                         if validation_callback and not result.did_error:
+                             await validation_callback(i, result)
+
+                         break
+                     except ValidationFailedException as e:
+                         retry_count += 1
+                         logger.warning(
+                             f"Validation failed for {model_name} after {retry_count} attempts: {e}"
+                         )
+                         if retry_count >= MAX_VALIDATION_ATTEMPTS:
+                             logger.error(
+                                 f"Validation failed for {model_name} after {MAX_VALIDATION_ATTEMPTS} attempts: {e}"
+                             )
+                             result = LLMOutput(
+                                 model=model_name,
+                                 completions=[],
+                                 errors=[e],
+                             )
+                             break
+                     except DocentUsageLimitException as e:
+                         result = LLMOutput(
+                             model=model_name,
+                             completions=[],
+                             errors=[],  # Usage limit exceptions will be added to all results later if cancelled_due_to_usage_limit
+                         )
+                         cancelled_due_to_usage_limit = True
+                         tg.cancel_scope.cancel()
+                         break
+                     except Exception as e:
+                         if not isinstance(e, LLMException):
+                             logger.warning(
+                                 f"LLM call raised an exception that is not an LLMException: {e}"
+                             )
+                             llm_exception = LLMException(e)
+                             llm_exception.__cause__ = e
+                         else:
+                             llm_exception = e
+
+                         error_message = f"Call to {model_name} failed even with backoff: {e.__class__.__name__}."
+
+                         if not isinstance(e, RateLimitException):
+                             error_message += f" Failure traceback:\n{traceback.format_exc()}"
+                         logger.error(error_message)
+
+                         result = LLMOutput(
+                             model=model_name,
+                             completions=[],
+                             errors=[llm_exception],
+                         )
+                         break
+
+             # Always call completion callback with final result (success or error)
+             if completion_callback and result is not None:
+                 try:
+                     await completion_callback(i, result)
+                 # LLMService uses this callback to record cost, and may throw an error if we just exceeded limit
+                 except DocentUsageLimitException as e:
+                     result.errors.append(e)
+                     cancelled_due_to_usage_limit = True
+                     tg.cancel_scope.cancel()
+
+             responses[i] = result
+             if pbar is not None:
+                 pbar.update(1)
+             if pbar is None or pbar.n == pbar.total:
+                 tg.cancel_scope.cancel()
+
+     def _cache_responses():
+         nonlocal responses, cache
+
+         if cache is not None:
+             indices = [
+                 i
+                 for i, response in enumerate(responses)
+                 if resolved_messages[i] is not None
+                 and response is not None
+                 and not response.did_error
+             ]
+             cache.set_batch(
+                 # We already checked that each index has a resolved messages list
+                 [cast(list[ChatMessage], resolved_messages[i]) for i in indices],
+                 model_name,
+                 # We already checked that each index corresponds to an LLMOutput object
+                 [cast(LLMOutput, responses[i]) for i in indices],
+                 tools=tools,
+                 tool_choice=tool_choice,
+                 reasoning_effort=reasoning_effort,
+                 temperature=temperature,
+                 logprobs=logprobs,
+                 top_logprobs=top_logprobs,
+             )
+             return len(indices)
+         else:
+             return 0
+
+     # Get all results concurrently
+     try:
+         async with anyio.create_task_group() as tg:
+             # Start all the individual tasks
+             for i, cur_input in enumerate(inputs):
+                 tg.start_soon(_limited_task, i, cur_input, tg)
+
+     # Cache what we have so far if something got cancelled
+     except anyio.get_cancelled_exc_class():
+         num_cached = _cache_responses()
+         logger.info(
+             f"Cancelled {len(inputs) - num_cached} unfinished LLM API calls; cached {num_cached} completed responses"
+         )
+         raise
+
+     if cancelled_due_to_usage_limit:
+         for i in range(len(responses)):
+             if responses[i] is None:
+                 responses[i] = LLMOutput(
+                     model=model_name,
+                     completions=[],
+                     errors=[DocentUsageLimitException()],
+                 )
+             else:
+                 responses[i].errors.append(DocentUsageLimitException())
+
+     # Cache results if available
+     _cache_responses()
+
+     # At this point, all indices should have a result
+     assert all(
+         isinstance(r, LLMOutput) for r in responses
+     ), "Some indices were never set to an LLMOutput, which should never happen"
+
+     return cast(list[LLMOutput], responses)
+
+
+ class LLMManager:
+     def __init__(
+         self,
+         model_options: list[ModelOption],
+         api_key_overrides: dict[str, str] | None = None,
+         use_cache: bool = False,
+     ):
+         # TODO(mengk): make this more robust, possibly move to a NoSQL database or something
+         try:
+             self.cache = LLMCache() if use_cache else None
+         except ValueError as e:
+             logger.warning(f"Disabling LLM cache due to init error: {e}")
+             self.cache = None
+
+         self.model_options = model_options
+         self.current_model_option_index = 0
+         self.api_key_overrides = api_key_overrides or {}
+
+     async def get_completions(
+         self,
+         inputs: list[MessagesInput],
+         tools: list[ToolInfo] | None = None,
+         tool_choice: Literal["auto", "required"] | None = None,
+         max_new_tokens: int = 32,
+         temperature: float = 1.0,
+         logprobs: bool = False,
+         top_logprobs: int | None = None,
+         max_concurrency: int | None = None,
+         timeout: float = 5.0,
+         streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
+         validation_callback: AsyncLLMOutputStreamingCallback | None = None,
+         completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+     ) -> list[LLMOutput]:
+         while True:
+             # Parse the current model option
+             cur_option = self.model_options[self.current_model_option_index]
+             provider, model_name, reasoning_effort = (
+                 cur_option.provider,
+                 cur_option.model_name,
+                 cur_option.reasoning_effort,
+             )
+
+             override_key = self.api_key_overrides.get(provider)
+
+             client = PROVIDERS[provider]["async_client_getter"](override_key)
+             single_output_getter = PROVIDERS[provider]["single_output_getter"]
+             single_streaming_output_getter = PROVIDERS[provider]["single_streaming_output_getter"]
+
+             # Get completions for uncached messages
+             outputs: list[LLMOutput] = await _parallelize_calls(
+                 (
+                     single_output_getter
+                     if streaming_callback is None
+                     else single_streaming_output_getter
+                 ),
+                 streaming_callback,
+                 validation_callback,
+                 completion_callback,
+                 client,
+                 inputs,
+                 model_name,
+                 tools=tools,
+                 tool_choice=tool_choice,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 reasoning_effort=reasoning_effort,
+                 logprobs=logprobs,
+                 top_logprobs=top_logprobs,
+                 timeout=timeout,
+                 semaphore=(
+                     anyio.Semaphore(max_concurrency) if max_concurrency is not None else None
+                 ),
+                 cache=self.cache,
+             )
+             assert len(outputs) == len(inputs), "Number of outputs must match number of messages"
+
+             # Only count errors that should trigger model rotation (API errors, not validation/usage errors)
+             num_rotation_errors = sum(
+                 1
+                 for output in outputs
+                 if output.did_error
+                 and any(
+                     not isinstance(e, (ValidationFailedException, DocentUsageLimitException))
+                     for e in output.errors
+                 )
+             )
+             if num_rotation_errors > 0:
+                 logger.warning(f"{model_name}: {num_rotation_errors} API errors")
+                 if not self._rotate_model_option():
+                     break
+             else:
+                 break
+
+         return outputs
+
+     def _rotate_model_option(self) -> ModelOption | None:
+         self.current_model_option_index += 1
+         if self.current_model_option_index >= len(self.model_options):
+             logger.error("All model options are exhausted")
+             return None
+
+         new_model_option = self.model_options[self.current_model_option_index]
+         logger.warning(f"Switched to next model {new_model_option.model_name}")
+         return new_model_option
+
+
+ async def get_llm_completions_async(
+     inputs: list[MessagesInput],
+     model_options: list[ModelOption],
+     tools: list[ToolInfo] | None = None,
+     tool_choice: Literal["auto", "required"] | None = None,
+     max_new_tokens: int = 1024,
+     temperature: float = 1.0,
+     logprobs: bool = False,
+     top_logprobs: int | None = None,
+     max_concurrency: int = 100,
+     timeout: float = 120.0,
+     streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
+     validation_callback: AsyncLLMOutputStreamingCallback | None = None,
+     completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+     use_cache: bool = False,
+     api_key_overrides: dict[str, str] | None = None,
+ ) -> list[LLMOutput]:
+     # We don't support logprobs for Anthropic yet
+     if logprobs:
+         for model_option in model_options:
+             if model_option.provider == "anthropic":
+                 raise ValueError(
+                     f"Logprobs are not supported for Anthropic, so we can't use model {model_option.model_name}"
+                 )
+
+     # Create the LLM manager
+     llm_manager = LLMManager(
+         model_options=model_options,
+         api_key_overrides=api_key_overrides,
+         use_cache=use_cache,
+     )
+
+     return await llm_manager.get_completions(
+         inputs,
+         tools=tools,
+         tool_choice=tool_choice,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         logprobs=logprobs,
+         top_logprobs=top_logprobs,
+         max_concurrency=max_concurrency,
+         timeout=timeout,
+         streaming_callback=streaming_callback,
+         validation_callback=validation_callback,
+         completion_callback=completion_callback,
+     )
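
The second new file adds the batched-call machinery and the `get_llm_completions_async` entry point. A hedged usage sketch follows; the import path for the entry point, the `ModelOption` constructor, the "openai" provider key, and the dict shape accepted by `parse_chat_message` are assumptions, while the fields read from `ModelOption` and `LLMOutput` come from the code above.

# Hypothetical usage sketch; import path and ModelOption construction are assumptions.
import anyio

from docent._llm_util.providers.preference_types import ModelOption
from docent._llm_util.llm_svc import get_llm_completions_async  # assumed location


async def main() -> None:
    outputs = await get_llm_completions_async(
        # Each input may be a list of ChatMessage objects or role/content dicts (assumed shape).
        inputs=[[{"role": "user", "content": "Say hello."}]],
        # Options are tried in order; on API errors the manager rotates to the next one.
        model_options=[
            ModelOption(provider="openai", model_name="gpt-5-mini", reasoning_effort="low"),
            ModelOption(provider="anthropic", model_name="claude-sonnet-4", reasoning_effort=None),
        ],
        max_new_tokens=256,
        max_concurrency=10,
    )
    print(outputs[0].did_error, len(outputs[0].completions))


anyio.run(main)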