docent-python 0.1.19a0__py3-none-any.whl → 0.1.27a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python has been flagged as potentially problematic; see the package's registry page for more details.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +331 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/llm_svc.py +472 -0
- docent/_llm_util/model_registry.py +130 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/__init__.py +2 -2
- docent/data_models/agent_run.py +1 -0
- docent/data_models/judge.py +7 -4
- docent/data_models/transcript.py +2 -0
- docent/data_models/util.py +170 -0
- docent/judges/__init__.py +23 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +587 -0
- docent/judges/runner.py +129 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +311 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +86 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +87 -0
- docent/judges/util/voting.py +139 -0
- docent/sdk/client.py +181 -44
- docent/trace.py +362 -44
- {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/METADATA +11 -5
- docent_python-0.1.27a0.dist-info/RECORD +59 -0
- docent_python-0.1.19a0.dist-info/RECORD +0 -32
- {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/licenses/LICENSE.md +0 -0
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
class LLMException(Exception):
    """Base class for failures while obtaining a completion from an LLM.

    Subclasses override both class attributes: ``error_type_id`` is the
    identifier written out when errors are serialized, and ``user_message``
    is text suitable for showing to a user.
    """

    # Text suitable for showing to an end user.
    user_message = "The model failed to respond. Please try again later."
    # Identifier used when this error is serialized.
    error_type_id = "other"
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CompletionTooLongException(LLMException):
    """Signals that a completion was cut off because it hit the length limit."""

    user_message = "Completion too long."
    error_type_id = "completion_too_long"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RateLimitException(LLMException):
    """Signals that the model provider rate-limited the request."""

    user_message = "Rate limited by the model provider. Please wait and try again."
    error_type_id = "rate_limit"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ContextWindowException(LLMException):
    """Signals that the request exceeded the model's context window."""

    user_message = "Context window exceeded."
    error_type_id = "context_window"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class NoResponseException(LLMException):
    """Signals that the model returned an empty response."""

    user_message = "The model returned an empty response. Please try again later."
    error_type_id = "no_response"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DocentUsageLimitException(LLMException):
    """Signals that the free daily Docent usage limit has been reached."""

    user_message = "Free daily usage limit reached. Add your own API key in settings or contact us for increased limits."
    error_type_id = "docent_usage_limit"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ValidationFailedException(LLMException):
    """Signals that the model's output failed validation.

    Args:
        message: Message passed through to the base ``Exception``.
        failed_output: The model output that failed validation, if available.
    """

    user_message = "The model returned invalid output that failed validation."
    error_type_id = "validation_failed"

    def __init__(self, message: str = "", failed_output: str | None = None):
        super().__init__(message)
        # Retained on the instance so callers can inspect the rejected output.
        self.failed_output = failed_output
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# All LLMException types, listed so a serialized ``error_type_id`` can be
# mapped back to its class (LLMOutput.from_dict builds its lookup from this).
LLM_ERROR_TYPES: list[type[LLMException]] = [
    LLMException,  # "other"
    CompletionTooLongException,  # "completion_too_long"
    RateLimitException,  # "rate_limit"
    ContextWindowException,  # "context_window"
    NoResponseException,  # "no_response"
    DocentUsageLimitException,  # "docent_usage_limit"
    ValidationFailedException,  # "validation_failed"
]
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Literal, Protocol, cast
|
|
4
|
+
|
|
5
|
+
from openai.types.chat.chat_completion_token_logprob import TopLogprob
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from docent._llm_util.data_models.exceptions import (
|
|
9
|
+
LLM_ERROR_TYPES,
|
|
10
|
+
CompletionTooLongException,
|
|
11
|
+
ContextWindowException,
|
|
12
|
+
LLMException,
|
|
13
|
+
)
|
|
14
|
+
from docent._log_util import get_logger
|
|
15
|
+
from docent.data_models.chat import ToolCall
|
|
16
|
+
|
|
17
|
+
# Module-level logger; used below for cache-load error logs and truncation warnings.
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
# Possible reasons for an LLM completion to finish.
FinishReasonType = Literal[
    "error",
    "stop",
    "length",
    "tool_calls",
    "content_filter",
    "function_call",
    "streaming",
    "refusal",
]

# Token categories that UsageMetrics keys its counts by.
TokenType = Literal["input", "output", "cache_read", "cache_write"]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class UsageMetrics:
    """Token counts for an LLM call, keyed by token type.

    Token types that were never recorded read back as 0. Counts passed as
    ``None`` to the constructor are treated as unrecorded.
    """

    _usage: dict[TokenType, int]

    def __init__(self, **kwargs: int | None):
        # Drop None counts so they behave like unrecorded entries (read as 0).
        filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        self._usage = cast(dict[TokenType, int], filtered_kwargs)

    def __getitem__(self, key: TokenType) -> int:
        """Return the count for ``key``, or 0 if it was never recorded."""
        return self._usage.get(key, 0)

    def __setitem__(self, key: TokenType, value: int):
        """Record the count for ``key``, replacing any previous value."""
        self._usage[key] = value

    def __repr__(self) -> str:
        # Constructor-style repr so LLMOutput's dataclass repr shows real counts
        # instead of an opaque "<UsageMetrics object at 0x...>".
        counts = ", ".join(f"{k}={v!r}" for k, v in self._usage.items())
        return f"{type(self).__name__}({counts})"

    def to_dict(self) -> dict[TokenType, int]:
        """Return the recorded counts as a new dict, omitting zero values."""
        # Filter out 0 values to avoid cluttering the database
        return {k: v for k, v in self._usage.items() if v != 0}

    @property
    def total_tokens(self) -> int:
        """Sum of input and output tokens (cache read/write counts are not included)."""
        return self["input"] + self["output"]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class LLMCompletion(BaseModel):
    """One completion produced by an LLM.

    Attributes:
        text: Generated text content, if any.
        tool_calls: Tool calls emitted by the completion, if any.
        finish_reason: Why generation stopped.
        top_logprobs: For each generated token, the top candidate tokens with
            their log-probabilities.
        reasoning_tokens: Reasoning content from the model as a string, if present.
    """

    text: str | None = None
    tool_calls: list[ToolCall] | None = None
    finish_reason: FinishReasonType | None = None
    top_logprobs: list[list[TopLogprob]] | None = None
    reasoning_tokens: str | None = None

    @property
    def no_text(self) -> bool:
        """True when ``text`` is None or the empty string."""
        return not self.text
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class LLMOutput:
    """Result of one LLM request, possibly holding several completions.

    Attributes:
        model: Name/identifier of the model that was used.
        completions: The individual completions.
        errors: Errors encountered during generation.
        usage: Token usage counts for the request.
        from_cache: True when this output was loaded from the LLM cache.
        duration: Elapsed time for the request, if recorded.
    """

    model: str
    completions: list[LLMCompletion]
    errors: list[LLMException] = field(default_factory=list)
    usage: UsageMetrics = field(default_factory=UsageMetrics)
    from_cache: bool = False
    duration: float | None = None

    @property
    def non_empty(self) -> bool:
        """True if at least one completion is present."""
        return bool(self.completions)

    @property
    def first(self) -> LLMCompletion | None:
        """The first completion, or None when there are no completions."""
        if not self.non_empty:
            return None
        return self.completions[0]

    @property
    def first_text(self) -> str | None:
        """Text of the first completion, or None when there is no completion."""
        head = self.first
        return head.text if head else None

    @property
    def did_error(self) -> bool:
        """True if any error was recorded."""
        return len(self.errors) > 0

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; errors are stored by their ``error_type_id``."""
        serialized_completions = [completion.model_dump() for completion in self.completions]
        serialized_errors = [err.error_type_id for err in self.errors]
        return {
            "model": self.model,
            "completions": serialized_completions,
            "errors": serialized_errors,
            "usage": self.usage.to_dict(),
            "from_cache": self.from_cache,
            "duration": self.duration,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "LLMOutput":
        """Rebuild an LLMOutput from a dict in the format produced by ``to_dict``.

        Error ids not found in ``LLM_ERROR_TYPES`` fall back to ``LLMException``.
        An error log is emitted when any stored error is something other than a
        completion-too-long or context-window error.

        Raises:
            KeyError: If ``data`` has no ``"model"`` entry.
        """
        exc_by_id = {exc_cls.error_type_id: exc_cls for exc_cls in LLM_ERROR_TYPES}
        raw_errors = data.get("errors", [])

        # These error types are excluded from the error log below.
        quiet_ids: list[str] = [
            CompletionTooLongException.error_type_id,
            ContextWindowException.error_type_id,
        ]
        if any(err_id not in quiet_ids for err_id in raw_errors):
            logger.error(f"Loading LLM output with errors: {raw_errors}")
        errors = [exc_by_id.get(err_id, LLMException)() for err_id in raw_errors]

        completions = [
            LLMCompletion.model_validate(raw_completion)
            for raw_completion in data.get("completions", [])
        ]

        usage: dict[TokenType, int] = {}
        stored_usage = data.get("usage")
        if stored_usage:
            usage = cast(dict[TokenType, int], stored_usage)

        return cls(
            model=data["model"],
            completions=completions,
            errors=errors,
            usage=UsageMetrics(**usage),
            from_cache=bool(data.get("from_cache", False)),
            duration=data.get("duration"),
        )
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@dataclass
|
|
179
|
+
class ToolCallPartial:
|
|
180
|
+
"""Partial representation of a tool call before full processing.
|
|
181
|
+
|
|
182
|
+
Used as an intermediate format before finalizing into a complete ToolCall.
|
|
183
|
+
|
|
184
|
+
Args:
|
|
185
|
+
id: The identifier for the tool call.
|
|
186
|
+
function: The name of the function to call.
|
|
187
|
+
arguments_raw: Raw JSON string of arguments for the function.
|
|
188
|
+
type: The type of the tool call, always "function".
|
|
189
|
+
"""
|
|
190
|
+
|
|
191
|
+
id: str | None
|
|
192
|
+
function: str | None
|
|
193
|
+
arguments_raw: str | None
|
|
194
|
+
type: Literal["function"]
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class LLMCompletionPartial(LLMCompletion):
    """An LLMCompletion whose tool calls are still unparsed.

    Identical to ``LLMCompletion`` except that ``tool_calls`` holds
    ``ToolCallPartial`` entries (or None placeholders), which are parsed into
    real ``ToolCall`` objects during finalization.

    Attributes:
        tool_calls: Partial tool calls; None entries are skipped at finalization.
    """

    # Narrows the parent's field type, which type checkers flag as an override.
    tool_calls: list[ToolCallPartial | None] | None = None  # type: ignore
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class LLMOutputPartial(LLMOutput):
    """An LLMOutput whose completions are still partial.

    Used as an intermediate container during processing; converted into a
    regular ``LLMOutput`` by ``finalize_llm_output_partial``.

    Attributes:
        completions: Partial completions awaiting finalization.
    """

    # This class is not re-decorated with @dataclass, so the annotation below
    # only narrows the type for type checkers; __init__ and the dataclass
    # fields are inherited unchanged from LLMOutput.
    completions: list[LLMCompletionPartial]  # type: ignore
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def finalize_llm_output_partial(partial: LLMOutputPartial) -> LLMOutput:
    """Convert a partial LLM output into a finalized LLM output.

    Parses each tool call's raw JSON arguments and copies every completion
    field (text, finish reason, top logprobs, reasoning tokens). It also copies
    the output-level fields (usage, errors, duration) and warns about
    completions truncated for length.

    Args:
        partial: The partial LLM output to finalize.

    Returns:
        LLMOutput: The finalized output with parsed tool calls and
            ``from_cache`` set to False.

    Raises:
        CompletionTooLongException: If the first completion was truncated due to
            length and resulted in empty text.
        ValueError: If a tool call's ID or function is missing in the partial data.
    """

    def _parse_tool_call(tc_partial: ToolCallPartial) -> ToolCall:
        # Build a ToolCall; argument-parsing problems are recorded on the
        # ToolCall's parse_error instead of being raised.
        if tc_partial.id is None:
            raise ValueError("Tool call ID not found in partial; check for parsing errors")
        if tc_partial.function is None:
            raise ValueError("Tool call function not found in partial; check for parsing errors")

        parse_error: str | None = None
        try:
            arguments: Any = json.loads(tc_partial.arguments_raw or "{}")
        except (ValueError, TypeError, RecursionError) as e:
            # Not valid JSON: keep the raw text so it is not lost, and record why.
            arguments = {"__parse_error_raw_args": tc_partial.arguments_raw}
            parse_error = f"Couldn't parse tool call arguments as JSON: {e}. Original input: {tc_partial.arguments_raw}"
        else:
            if not isinstance(arguments, dict):
                # Valid JSON but not an object (e.g. a list or number); arguments
                # are expected to be a dict[str, Any], so treat it as a parse error.
                parse_error = (
                    "Tool call arguments must be a JSON object, got "
                    f"{type(arguments).__name__}. Original input: {tc_partial.arguments_raw}"
                )
                arguments = {"__parse_error_raw_args": tc_partial.arguments_raw}

        return ToolCall(
            id=tc_partial.id,
            function=tc_partial.function,
            arguments=arguments,
            parse_error=parse_error,
            type=tc_partial.type,
        )

    output = LLMOutput(
        model=partial.model,
        completions=[
            LLMCompletion(
                text=c.text,
                tool_calls=[_parse_tool_call(tc) for tc in (c.tool_calls or []) if tc is not None],
                finish_reason=c.finish_reason,
                # Previously dropped, which silently lost logprobs on finalization.
                top_logprobs=c.top_logprobs,
                reasoning_tokens=c.reasoning_tokens,
            )
            for c in partial.completions
        ],
        # Copy the list so the finalized output does not share it with the partial.
        errors=list(partial.errors),
        usage=partial.usage,
        from_cache=False,
        duration=partial.duration,
    )

    # If the completion is empty and was truncated (likely due to too much reasoning), raise an exception
    if output.first and output.first.finish_reason == "length" and output.first.no_text:
        raise CompletionTooLongException(
            "Completion empty due to truncation. Consider increasing max_new_tokens."
        )
    for c in output.completions:
        if c.finish_reason == "length":
            logger.warning(
                "Completion truncated due to length; consider increasing max_new_tokens."
            )

    return output
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class AsyncLLMOutputStreamingCallback(Protocol):
    """Async callback that receives streamed LLM output tagged with a batch index.

    Args:
        batch_index: Index of the batch the output belongs to.
        llm_output: The LLM output for that batch.
    """

    async def __call__(self, batch_index: int, llm_output: LLMOutput) -> None:
        ...
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
class AsyncSingleLLMOutputStreamingCallback(Protocol):
    """Async callback that receives streamed LLM output without a batch index.

    Args:
        llm_output: The LLM output to handle.
    """

    async def __call__(self, llm_output: LLMOutput) -> None:
        ...
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
class AsyncEmbeddingStreamingCallback(Protocol):
    """Async callback used to report progress while embeddings are generated.

    Args:
        progress: Progress value to report; its scale is up to the caller.
    """

    async def __call__(self, progress: int) -> None:
        ...
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import sqlite3
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
from docent._llm_util.data_models.llm_output import LLMOutput
|
|
10
|
+
from docent._log_util import get_logger
|
|
11
|
+
from docent.data_models.chat import ChatMessage, ToolInfo
|
|
12
|
+
|
|
13
|
+
# Module-level logger (not currently referenced elsewhere in this module).
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LLMCache:
    """SQLite-backed cache of LLM outputs.

    Each entry is keyed by a SHA-256 hash of the messages, model name, tools,
    and generation settings. Its value is the ``LLMOutput`` serialized as JSON
    via ``LLMOutput.to_dict``. A new SQLite connection is opened for every
    operation.
    """

    def __init__(self, db_path: str | None = None):
        """Open the cache, creating the database table if needed.

        Args:
            db_path: Path to the SQLite database file. When omitted, the file is
                ``llm_cache.db`` inside the directory named by the
                ``LLM_CACHE_PATH`` environment variable (the directory is
                created if missing).

        Raises:
            ValueError: If ``db_path`` is omitted and ``LLM_CACHE_PATH`` is
                unset or empty.
        """
        if db_path is None:
            llm_cache_path = os.getenv("LLM_CACHE_PATH")
            if not llm_cache_path:
                raise ValueError("LLM_CACHE_PATH is not set")
            cache_dir = Path(llm_cache_path)
            cache_dir.mkdir(parents=True, exist_ok=True)
            db_path = str(cache_dir / "llm_cache.db")

        self.db_path = db_path
        self._init_db()

    def _init_db(self) -> None:
        """Create the ``llm_cache`` table if it does not already exist."""
        with self._get_connection() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS llm_cache (
                    key TEXT PRIMARY KEY,
                    completion TEXT,
                    model_name TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            # Explicit commit so the DDL is persisted regardless of the
            # connection's transaction mode.
            conn.commit()

    @contextmanager
    def _get_connection(self):
        """Yield a fresh connection to the cache database, closing it afterwards."""
        conn = sqlite3.connect(self.db_path)
        try:
            yield conn
        finally:
            conn.close()

    def _create_key(
        self,
        messages: list[ChatMessage],
        model_name: str,
        *,
        tools: list[ToolInfo] | None = None,
        tool_choice: Literal["auto", "required"] | None = None,
        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
        temperature: float = 1.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
    ) -> str:
        """Create a deterministic hash key from messages and model.

        NOTE: changing the key format invalidates every existing cache entry.
        """
        # Convert messages to a stable string representation; message ids are
        # excluded so they do not affect the key.
        message_str = json.dumps(
            [msg.model_dump(exclude={"id"}) for msg in messages], sort_keys=True
        )

        # Convert tools to a stable string representation if present
        tools_str = (
            json.dumps([tool.model_dump() for tool in tools], sort_keys=True) if tools else None
        )

        # Combine all parameters into a single string
        key_str = (
            f"{message_str}:{model_name}:{tools_str}:{tool_choice}:{reasoning_effort}:{temperature}"
        )
        if logprobs:
            key_str += f":{top_logprobs}"
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get(
        self,
        messages: list[ChatMessage],
        model_name: str,
        *,
        tools: list[ToolInfo] | None = None,
        tool_choice: Literal["auto", "required"] | None = None,
        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
        temperature: float = 1.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
    ) -> LLMOutput | None:
        """Get the cached output for a request, or None if there is no entry.

        Returned outputs have ``from_cache`` set to True.
        """
        key = self._create_key(
            messages,
            model_name,
            tools=tools,
            tool_choice=tool_choice,
            reasoning_effort=reasoning_effort,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
        )

        with self._get_connection() as conn:
            cursor = conn.execute("SELECT completion FROM llm_cache WHERE key = ?", (key,))
            result = cursor.fetchone()
            if not result:
                return None
            out = LLMOutput.from_dict(json.loads(result[0]))
            out.from_cache = True
            return out

    def set(
        self,
        messages: list[ChatMessage],
        model_name: str,
        llm_output: LLMOutput,
        *,
        tools: list[ToolInfo] | None = None,
        tool_choice: Literal["auto", "required"] | None = None,
        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
        temperature: float = 1.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
    ) -> None:
        """Cache an output for a single request, replacing any existing entry."""
        key = self._create_key(
            messages,
            model_name,
            tools=tools,
            tool_choice=tool_choice,
            reasoning_effort=reasoning_effort,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
        )

        with self._get_connection() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO llm_cache (key, completion, model_name) VALUES (?, ?, ?)",
                (key, json.dumps(llm_output.to_dict()), model_name),
            )
            conn.commit()

    def set_batch(
        self,
        messages_list: list[list[ChatMessage]],
        model_name: str,
        llm_output_list: list[LLMOutput],
        *,
        tools: list[ToolInfo] | None = None,
        tool_choice: Literal["auto", "required"] | None = None,
        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
        temperature: float = 1.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
    ) -> None:
        """Cache outputs for a batch of requests in a single transaction.

        ``messages_list[i]`` is paired with ``llm_output_list[i]``; all requests
        share the same model and generation settings.

        Raises:
            ValueError: If the two lists have different lengths (previously the
                extra items were silently dropped by ``zip``).
        """
        if len(messages_list) != len(llm_output_list):
            raise ValueError(
                "messages_list and llm_output_list must have the same length "
                f"(got {len(messages_list)} and {len(llm_output_list)})"
            )

        keys: list[str] = [
            self._create_key(
                messages,
                model_name,
                tools=tools,
                tool_choice=tool_choice,
                reasoning_effort=reasoning_effort,
                temperature=temperature,
                logprobs=logprobs,
                top_logprobs=top_logprobs,
            )
            for messages in messages_list
        ]

        with self._get_connection() as conn:
            conn.executemany(
                "INSERT OR REPLACE INTO llm_cache (key, completion, model_name) VALUES (?, ?, ?)",
                [
                    (key, json.dumps(llm_output.to_dict()), model_name)
                    for key, llm_output in zip(keys, llm_output_list)
                ],
            )
            conn.commit()

    def clear(self) -> None:
        """Clear all cached completions."""
        with self._get_connection() as conn:
            conn.execute("DELETE FROM llm_cache")
            conn.commit()
|