docent-python 0.1.17a0__py3-none-any.whl → 0.1.27a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic.

Files changed (45)
  1. docent/_llm_util/__init__.py +0 -0
  2. docent/_llm_util/data_models/__init__.py +0 -0
  3. docent/_llm_util/data_models/exceptions.py +48 -0
  4. docent/_llm_util/data_models/llm_output.py +331 -0
  5. docent/_llm_util/llm_cache.py +193 -0
  6. docent/_llm_util/llm_svc.py +472 -0
  7. docent/_llm_util/model_registry.py +130 -0
  8. docent/_llm_util/providers/__init__.py +0 -0
  9. docent/_llm_util/providers/anthropic.py +537 -0
  10. docent/_llm_util/providers/common.py +41 -0
  11. docent/_llm_util/providers/google.py +530 -0
  12. docent/_llm_util/providers/openai.py +745 -0
  13. docent/_llm_util/providers/openrouter.py +375 -0
  14. docent/_llm_util/providers/preference_types.py +104 -0
  15. docent/_llm_util/providers/provider_registry.py +164 -0
  16. docent/data_models/__init__.py +2 -0
  17. docent/data_models/agent_run.py +6 -5
  18. docent/data_models/chat/__init__.py +6 -1
  19. docent/data_models/citation.py +103 -22
  20. docent/data_models/judge.py +19 -0
  21. docent/data_models/metadata_util.py +16 -0
  22. docent/data_models/remove_invalid_citation_ranges.py +23 -10
  23. docent/data_models/transcript.py +20 -16
  24. docent/data_models/util.py +170 -0
  25. docent/judges/__init__.py +23 -0
  26. docent/judges/analysis.py +77 -0
  27. docent/judges/impl.py +587 -0
  28. docent/judges/runner.py +129 -0
  29. docent/judges/stats.py +205 -0
  30. docent/judges/types.py +311 -0
  31. docent/judges/util/forgiving_json.py +108 -0
  32. docent/judges/util/meta_schema.json +86 -0
  33. docent/judges/util/meta_schema.py +29 -0
  34. docent/judges/util/parse_output.py +87 -0
  35. docent/judges/util/voting.py +139 -0
  36. docent/sdk/agent_run_writer.py +62 -19
  37. docent/sdk/client.py +244 -23
  38. docent/trace.py +413 -90
  39. {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/METADATA +11 -5
  40. docent_python-0.1.27a0.dist-info/RECORD +59 -0
  41. docent/data_models/metadata.py +0 -229
  42. docent/data_models/yaml_util.py +0 -12
  43. docent_python-0.1.17a0.dist-info/RECORD +0 -32
  44. {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/WHEEL +0 -0
  45. {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/licenses/LICENSE.md +0 -0
docent/_llm_util/llm_svc.py
@@ -0,0 +1,472 @@
+ import time
+ import traceback
+ from functools import partial
+ from typing import (
+     Any,
+     Literal,
+     Protocol,
+     Sequence,
+     cast,
+     runtime_checkable,
+ )
+
+ import anyio
+ from anyio import Lock, Semaphore
+ from anyio.abc import TaskGroup
+ from tqdm.auto import tqdm
+
+ from docent._llm_util.data_models.exceptions import (
+     DocentUsageLimitException,
+     LLMException,
+     RateLimitException,
+     ValidationFailedException,
+ )
+ from docent._llm_util.data_models.llm_output import (
+     AsyncLLMOutputStreamingCallback,
+     AsyncSingleLLMOutputStreamingCallback,
+     LLMOutput,
+ )
+ from docent._llm_util.llm_cache import LLMCache
+ from docent._llm_util.providers.preference_types import ModelOption
+ from docent._llm_util.providers.provider_registry import (
+     PROVIDERS,
+     SingleOutputGetter,
+     SingleStreamingOutputGetter,
+ )
+ from docent._log_util import get_logger
+ from docent.data_models.chat import ChatMessage, ToolInfo, parse_chat_message
+
+ logger = get_logger(__name__)
+
+ MAX_VALIDATION_ATTEMPTS = 3
+ DEFAULT_MAX_CONCURRENCY = 100
+ DEFAULT_SVC_MAX_CONCURRENCY = 100
+
+
+ @runtime_checkable
+ class MessageResolver(Protocol):
+     def __call__(self) -> list[ChatMessage | dict[str, Any]]: ...
+
+
+ MessagesInput = Sequence[ChatMessage | dict[str, Any]] | MessageResolver
+
+
+ def _resolve_messages_input(messages_input: MessagesInput) -> list[ChatMessage]:
+     raw_messages = (
+         messages_input() if isinstance(messages_input, MessageResolver) else messages_input
+     )
+     return [parse_chat_message(msg) for msg in raw_messages]
+
+
+ def _get_single_streaming_callback(
+     batch_index: int,
+     streaming_callback: AsyncLLMOutputStreamingCallback,
+ ) -> AsyncSingleLLMOutputStreamingCallback:
+     async def single_streaming_callback(llm_output: LLMOutput):
+         await streaming_callback(batch_index, llm_output)
+
+     return single_streaming_callback
+
+
+ async def _parallelize_calls(
+     single_output_getter: SingleOutputGetter | SingleStreamingOutputGetter,
+     streaming_callback: AsyncLLMOutputStreamingCallback | None,
+     validation_callback: AsyncLLMOutputStreamingCallback | None,
+     completion_callback: AsyncLLMOutputStreamingCallback | None,
+     # Arguments for the individual completion getter
+     client: Any,
+     inputs: Sequence[MessagesInput],
+     model_name: str,
+     tools: list[ToolInfo] | None,
+     tool_choice: Literal["auto", "required"] | None,
+     max_new_tokens: int,
+     temperature: float,
+     reasoning_effort: Literal["minimal", "low", "medium", "high"] | None,
+     logprobs: bool,
+     top_logprobs: int | None,
+     timeout: float,
+     semaphore: Semaphore,
+     # use_tqdm: bool,
+     cache: LLMCache | None = None,
+ ):
+     base_func = partial(
+         single_output_getter,
+         client=client,
+         model_name=model_name,
+         tools=tools,
+         tool_choice=tool_choice,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         reasoning_effort=reasoning_effort,
+         logprobs=logprobs,
+         top_logprobs=top_logprobs,
+         timeout=timeout,
+     )
+
+     responses: list[LLMOutput | None] = [None for _ in inputs]
+     pbar = (
+         tqdm(
+             total=len(inputs),
+             desc=f"Calling {model_name} (reasoning_effort={reasoning_effort}) API",
+         )
+         if len(inputs) > 1
+         else None
+     )
+
+     # Save resolved messages to avoid multiple resolutions
+     resolved_messages: list[list[ChatMessage] | None] = [None] * len(inputs)
+
+     # Not sure why the cast is necessary for the type checker
+     cancelled_due_to_usage_limit: bool = cast(bool, False)
+
+     async def _limited_task(i: int, cur_input: MessagesInput, tg: TaskGroup):
+         nonlocal responses, pbar, resolved_messages, cancelled_due_to_usage_limit
+
+         async with semaphore:
+             messages = _resolve_messages_input(cur_input)
+             resolved_messages[i] = messages
+
+             retry_count = 0
+             result = None
+             call_started_at: float | None = None
+
+             # Check if there's a cached result
+             cached_result = (
+                 cache.get(
+                     messages,
+                     model_name,
+                     tools=tools,
+                     tool_choice=tool_choice,
+                     reasoning_effort=reasoning_effort,
+                     temperature=temperature,
+                     logprobs=logprobs,
+                     top_logprobs=top_logprobs,
+                 )
+                 if cache is not None
+                 else None
+             )
+             if cached_result is not None:
+                 result = cached_result
+                 if streaming_callback is not None:
+                     await streaming_callback(i, result)
+             else:
+                 call_started_at = time.perf_counter()
+                 while retry_count < MAX_VALIDATION_ATTEMPTS:
+                     try:
+                         if streaming_callback is None:
+                             result = await base_func(client=client, messages=messages)
+                         else:
+                             result = await base_func(
+                                 client=client,
+                                 streaming_callback=_get_single_streaming_callback(
+                                     i, streaming_callback
+                                 ),
+                                 messages=messages,
+                             )
+
+                         # Validate if validation callback provided and result is successful
+                         if validation_callback and not result.did_error:
+                             await validation_callback(i, result)
+
+                         break
+                     except ValidationFailedException as e:
+                         retry_count += 1
+                         logger.warning(
+                             f"Validation failed for {model_name} after {retry_count} attempts: {e}"
+                         )
+                         if retry_count >= MAX_VALIDATION_ATTEMPTS:
+                             logger.error(
+                                 f"Validation failed for {model_name} after {retry_count} attempts. Original output: {e.failed_output}"
+                             )
+                             result = LLMOutput(
+                                 model=model_name,
+                                 completions=[],
+                                 errors=[e],
+                             )
+                             break
+                     except DocentUsageLimitException as _:
+                         result = LLMOutput(
+                             model=model_name,
+                             completions=[],
+                             errors=[],  # Usage limit exceptions will be added to all results later if cancelled_due_to_usage_limit
+                         )
+                         cancelled_due_to_usage_limit = True
+                         tg.cancel_scope.cancel()
+                         break
+                     except Exception as e:
+                         if not isinstance(e, LLMException):
+                             logger.error(
+                                 f"LLM call raised an exception that is not an LLMException: {e}. Failure traceback:\n{traceback.format_exc()}"
+                             )
+                             llm_exception = LLMException(e)
+                             llm_exception.__cause__ = e
+                         else:
+                             llm_exception = e
+
+                         error_message = f"Call to {model_name} failed even with backoff: {e.__class__.__name__}."
+
+                         if not isinstance(e, RateLimitException):
+                             error_message += f" Failure traceback:\n{traceback.format_exc()}"
+                         logger.error(error_message)
+
+                         result = LLMOutput(
+                             model=model_name,
+                             completions=[],
+                             errors=[llm_exception],
+                         )
+                         break
+
+             # Only store the elapsed time if we didn't hit the cache and the call was successful
+             if cached_result is None and result is not None and call_started_at is not None:
+                 result.duration = time.perf_counter() - call_started_at
+
+             # Always call completion callback with final result (success or error)
+             if completion_callback and result is not None:
+                 try:
+                     await completion_callback(i, result)
+                 # LLMService uses this callback to record cost, and may throw an error if we just exceeded limit
+                 except DocentUsageLimitException as e:
+                     result.errors.append(e)
+                     cancelled_due_to_usage_limit = True
+                     tg.cancel_scope.cancel()
+
+             responses[i] = result
+             if pbar is not None:
+                 pbar.update(1)
+             if pbar is None or pbar.n == pbar.total:
+                 tg.cancel_scope.cancel()
+
+     def _cache_responses():
+         nonlocal responses, cache
+
+         if cache is not None:
+             indices = [
+                 i
+                 for i, response in enumerate(responses)
+                 if resolved_messages[i] is not None
+                 and response is not None
+                 and not response.did_error
+             ]
+             cache.set_batch(
+                 # We already checked that each index has a resolved messages list
+                 [cast(list[ChatMessage], resolved_messages[i]) for i in indices],
+                 model_name,
+                 # We already checked that each index corresponds to an LLMOutput object
+                 [cast(LLMOutput, responses[i]) for i in indices],
+                 tools=tools,
+                 tool_choice=tool_choice,
+                 reasoning_effort=reasoning_effort,
+                 temperature=temperature,
+                 logprobs=logprobs,
+                 top_logprobs=top_logprobs,
+             )
+             return len(indices)
+         else:
+             return 0
+
+     # Get all results concurrently
+     try:
+         async with anyio.create_task_group() as tg:
+             # Start all the individual tasks
+             for i, cur_input in enumerate(inputs):
+                 tg.start_soon(_limited_task, i, cur_input, tg)
+
+     # Cache what we have so far if something got cancelled
+     except anyio.get_cancelled_exc_class():
+         num_cached = _cache_responses()
+         if num_cached:
+             logger.info(
+                 f"Cancelled {len(inputs) - num_cached} unfinished LLM API calls, but cached {num_cached} completed responses"
+             )
+
+         # If the task was cancelled due to usage limit, set the response to a usage limit exception
+         if cancelled_due_to_usage_limit:
+             for i, response in enumerate(responses):
+                 if response is None:
+                     responses[i] = LLMOutput(
+                         model=model_name,
+                         completions=[],
+                         errors=[DocentUsageLimitException()],
+                     )
+                 else:
+                     response.errors.append(DocentUsageLimitException())
+
+         raise
+
+     # Cache results if available
+     _cache_responses()
+
+     # At this point, all indices should have a result
+     assert all(
+         isinstance(r, LLMOutput) for r in responses
+     ), "Some indices were never set to an LLMOutput, which should never happen"
+
+     return cast(list[LLMOutput], responses)
+
+
+ class BaseLLMService:
+     def __init__(self, max_concurrency: int = DEFAULT_SVC_MAX_CONCURRENCY):
+         self.max_concurrency, self._semaphore = max_concurrency, Semaphore(max_concurrency)
+         self._client_cache: dict[tuple[str, str | None], Any] = {}  # (provider, api_key) -> client
+         self._client_cache_lock = Lock()
+
+     async def _get_cached_client(self, provider: str, override_key: str | None) -> Any:
+         """Return a cached client for the provider/api-key tuple, creating one if needed."""
+         cache_key = (provider, override_key)
+         async with self._client_cache_lock:
+             cached = self._client_cache.get(cache_key)
+             if cached is not None:
+                 return cached
+
+             client_factory = PROVIDERS[provider]["async_client_getter"]
+             new_client = client_factory(override_key)
+             self._client_cache[cache_key] = new_client
+             return new_client
+
+     async def get_completions(
+         self,
+         *,
+         inputs: Sequence[MessagesInput],
+         model_options: list[ModelOption],
+         tools: list[ToolInfo] | None = None,
+         tool_choice: Literal["auto", "required"] | None = None,
+         max_new_tokens: int = 1024,
+         temperature: float = 1.0,
+         logprobs: bool = False,
+         top_logprobs: int | None = None,
+         timeout: float = 120.0,
+         streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
+         validation_callback: AsyncLLMOutputStreamingCallback | None = None,
+         completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+         use_cache: bool = False,
+         _api_key_overrides: dict[str, str] = dict(),
+     ) -> list[LLMOutput]:
+         """Request completions from a configured LLM provider."""
+
+         # We don't support logprobs for Anthropic yet
+         if logprobs:
+             for model_option in model_options:
+                 if model_option.provider == "anthropic":
+                     raise ValueError(
+                         f"Logprobs are not supported for Anthropic, so we can't use model {model_option.model_name}"
+                     )
+
+         # Instantiate cache
+         # TODO(mengk): make this more robust, possibly move to a NoSQL database or something
+         try:
+             cache = LLMCache() if use_cache else None
+         except ValueError as e:
+             logger.warning(f"Disabling LLM cache due to init error: {e}")
+             cache = None
+
+         # Initialize pointer to which model we're using; used for model rotation after failures
+         current_model_option_index = 0
+
+         def _rotate_model_option() -> ModelOption | None:
+             nonlocal current_model_option_index
+
+             current_model_option_index += 1
+             if current_model_option_index >= len(model_options):
+                 logger.error("All model options are exhausted")
+                 return None
+
+             new_model_option = model_options[current_model_option_index]
+             logger.warning(f"Switched to next model {new_model_option.model_name}")
+             return new_model_option
+
+         while True:
+             # Parse the current model option
+             cur_option = model_options[current_model_option_index]
+             provider, model_name, reasoning_effort = (
+                 cur_option.provider,
+                 cur_option.model_name,
+                 cur_option.reasoning_effort,
+             )
+
+             override_key = _api_key_overrides.get(provider)
+
+             client = await self._get_cached_client(provider, override_key)
+             single_output_getter = PROVIDERS[provider]["single_output_getter"]
+             single_streaming_output_getter = PROVIDERS[provider]["single_streaming_output_getter"]
+
+             # Get completions for uncached messages
+             outputs: list[LLMOutput] = await _parallelize_calls(
+                 (
+                     single_output_getter
+                     if streaming_callback is None
+                     else single_streaming_output_getter
+                 ),
+                 streaming_callback,
+                 validation_callback,
+                 completion_callback,
+                 client,
+                 inputs,
+                 model_name,
+                 tools=tools,
+                 tool_choice=tool_choice,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 reasoning_effort=reasoning_effort,
+                 logprobs=logprobs,
+                 top_logprobs=top_logprobs,
+                 timeout=timeout,
+                 semaphore=self._semaphore,
+                 cache=cache,
+             )
+             assert len(outputs) == len(inputs), "Number of outputs must match number of messages"
+
+             # Only count errors that should trigger model rotation (API errors, not validation/usage errors)
+             num_rotation_errors = sum(
+                 1
+                 for output in outputs
+                 if output.did_error
+                 and any(
+                     not isinstance(e, (ValidationFailedException, DocentUsageLimitException))
+                     for e in output.errors
+                 )
+             )
+             if num_rotation_errors > 0:
+                 logger.warning(f"{model_name}: {num_rotation_errors} API errors")
+                 if not _rotate_model_option():
+                     break
+             else:
+                 break
+
+         return outputs
+
+
+ async def get_llm_completions_async(
+     inputs: list[MessagesInput],
+     model_options: list[ModelOption],
+     tools: list[ToolInfo] | None = None,
+     tool_choice: Literal["auto", "required"] | None = None,
+     max_new_tokens: int = 1024,
+     temperature: float = 1.0,
+     logprobs: bool = False,
+     top_logprobs: int | None = None,
+     timeout: float = 120.0,
+     streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
+     validation_callback: AsyncLLMOutputStreamingCallback | None = None,
+     completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+     use_cache: bool = False,
+     _api_key_overrides: dict[str, str] = dict(),
+ ) -> list[LLMOutput]:
+     """Convenience method for backward compatibility"""
+
+     svc = BaseLLMService()
+     return await svc.get_completions(
+         inputs=inputs,
+         model_options=model_options,
+         tools=tools,
+         tool_choice=tool_choice,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         logprobs=logprobs,
+         top_logprobs=top_logprobs,
+         timeout=timeout,
+         streaming_callback=streaming_callback,
+         validation_callback=validation_callback,
+         completion_callback=completion_callback,
+         use_cache=use_cache,
+         _api_key_overrides=_api_key_overrides,
+     )
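
For orientation, a minimal usage sketch of the get_llm_completions_async helper added above (editorial example, not part of the released diff). It assumes ModelOption can be constructed with the provider, model_name, and reasoning_effort fields this code reads from it, and that a plain role/content dict is accepted by parse_chat_message; check the actual ModelOption definition in docent/_llm_util/providers/preference_types.py before relying on it.

import anyio

from docent._llm_util.llm_svc import get_llm_completions_async
from docent._llm_util.providers.preference_types import ModelOption

async def main():
    # One batch item; a MessagesInput can also be a zero-argument resolver callable.
    outputs = await get_llm_completions_async(
        inputs=[[{"role": "user", "content": "Summarize this transcript."}]],
        # Field names assumed from how get_completions reads ModelOption.
        model_options=[ModelOption(provider="openai", model_name="gpt-5-mini", reasoning_effort="low")],
        max_new_tokens=256,
    )
    print(outputs[0].did_error, outputs[0].completions)

anyio.run(main)

If any model option fails with an API error, get_completions rotates to the next ModelOption in the list before giving up.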
docent/_llm_util/model_registry.py
@@ -0,0 +1,130 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from typing import Optional
+
+ from docent._llm_util.data_models.llm_output import TokenType
+ from docent._log_util import get_logger
+
+ logger = get_logger(__name__)
+
+
+ """
+ Values are USD per million tokens
+ """
+ ModelRate = dict[TokenType, float]
+
+
+ @dataclass(frozen=True)
+ class ModelInfo:
+     """
+     Information about a model, including its rate and context window. Not to be confused with ModelOption.
+     """
+
+     # Values are per 1,000,000 tokens
+     rate: Optional[ModelRate]
+     # Total context window tokens
+     context_window: int
+
+
+ # Note: some providers charge extra for long prompts/outputs. We don't account for this yet.
+ _REGISTRY: list[tuple[str, ModelInfo]] = [
+     (
+         "gpt-5-nano",
+         ModelInfo(rate={"input": 0.05, "output": 0.40}, context_window=400_000),
+     ),
+     (
+         "gpt-5-mini",
+         ModelInfo(rate={"input": 0.25, "output": 2.0}, context_window=400_000),
+     ),
+     (
+         "gpt-5",
+         ModelInfo(rate={"input": 1.25, "output": 10.0}, context_window=400_000),
+     ),
+     (
+         "gpt-4o",
+         ModelInfo(rate={"input": 2.50, "output": 10.00}, context_window=100_000),
+     ),
+     (
+         "o4-mini",
+         ModelInfo(rate={"input": 1.10, "output": 4.40}, context_window=100_000),
+     ),
+     (
+         "claude-sonnet-4",
+         ModelInfo(rate={"input": 3.0, "output": 15.0}, context_window=200_000),
+     ),
+     (
+         "claude-haiku-4-5",
+         ModelInfo(rate={"input": 1.0, "output": 5.0}, context_window=200_000),
+     ),
+     (
+         "gemini-2.5-flash-lite",
+         ModelInfo(
+             rate={"input": 0.10, "output": 0.40},
+             context_window=1_000_000,
+         ),
+     ),
+     (
+         "gemini-2.5-flash",
+         ModelInfo(
+             rate={"input": 0.30, "output": 2.50},
+             context_window=1_000_000,
+         ),
+     ),
+     (
+         "gemini-2.5-pro",
+         ModelInfo(
+             rate={"input": 1.25, "output": 10.00},
+             context_window=1_000_000,
+         ),
+     ),
+     (
+         "grok-4-fast",
+         ModelInfo(
+             rate={"input": 0.20, "output": 0.50},
+             context_window=2_000_000,
+         ),
+     ),
+     (
+         "grok-4",
+         ModelInfo(
+             rate={"input": 3.0, "output": 15.0},
+             context_window=256_000,
+         ),
+     ),
+ ]
+
+
+ @lru_cache(maxsize=None)
+ def get_model_info(model_name: str) -> Optional[ModelInfo]:
+     for registry_model_name, info in _REGISTRY:
+         if registry_model_name in model_name:
+             return info
+     return None
+
+
+ def get_context_window(model_name: str) -> int:
+     info = get_model_info(model_name)
+     if info is None:
+         logger.warning(f"No context window found for model {model_name}")
+         return 100_000
+     return info.context_window
+
+
+ def get_rates_for_model_name(model_name: str) -> Optional[ModelRate]:
+     info = get_model_info(model_name)
+     return info.rate if info is not None else None
+
+
+ def estimate_cost_cents(model_name: str, token_count: int, token_type: TokenType) -> float:
+     rate = get_rates_for_model_name(model_name)
+     if rate is None:
+         logger.warning(f"No rate found for model {model_name}")
+         return 0.0
+     usd_per_mtok = rate.get(token_type)
+     if usd_per_mtok is None:
+         logger.warning(f"No rate found for model {model_name} token type {token_type}")
+         return 0.0
+     cents_per_token = usd_per_mtok * 100 / 1_000_000.0
+     return token_count * cents_per_token
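
As a worked check on the rate arithmetic above (editorial example, not part of the released diff): a model name containing "gpt-5-mini" resolves, via substring match against _REGISTRY, to an input rate of 0.25 USD per million tokens; estimate_cost_cents converts that to 0.25 * 100 / 1_000_000 = 2.5e-05 cents per token, so 100,000 input tokens come to about 2.5 cents.

from docent._llm_util.model_registry import estimate_cost_cents, get_context_window

# Substring match against _REGISTRY picks the "gpt-5-mini" entry.
print(estimate_cost_cents("gpt-5-mini", 100_000, "input"))  # ~2.5 cents
print(get_context_window("gpt-5-mini"))  # 400000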