docent-python 0.1.20a0__tar.gz → 0.1.21a0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of docent-python might be problematic.

Files changed (60)
  1. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/PKG-INFO +6 -1
  2. docent_python-0.1.21a0/docent/_llm_util/data_models/exceptions.py +48 -0
  3. docent_python-0.1.21a0/docent/_llm_util/data_models/llm_output.py +320 -0
  4. docent_python-0.1.21a0/docent/_llm_util/data_models/simple_svc.py +79 -0
  5. docent_python-0.1.21a0/docent/_llm_util/llm_cache.py +193 -0
  6. docent_python-0.1.21a0/docent/_llm_util/model_registry.py +126 -0
  7. docent_python-0.1.21a0/docent/_llm_util/prod_llms.py +454 -0
  8. docent_python-0.1.21a0/docent/_llm_util/providers/__init__.py +0 -0
  9. docent_python-0.1.21a0/docent/_llm_util/providers/anthropic.py +537 -0
  10. docent_python-0.1.21a0/docent/_llm_util/providers/common.py +41 -0
  11. docent_python-0.1.21a0/docent/_llm_util/providers/google.py +530 -0
  12. docent_python-0.1.21a0/docent/_llm_util/providers/openai.py +745 -0
  13. docent_python-0.1.21a0/docent/_llm_util/providers/openrouter.py +375 -0
  14. docent_python-0.1.21a0/docent/_llm_util/providers/preference_types.py +104 -0
  15. docent_python-0.1.21a0/docent/_llm_util/providers/provider_registry.py +164 -0
  16. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/transcript.py +2 -0
  17. docent_python-0.1.21a0/docent/judges/__init__.py +21 -0
  18. docent_python-0.1.21a0/docent/judges/impl.py +222 -0
  19. docent_python-0.1.21a0/docent/judges/types.py +240 -0
  20. docent_python-0.1.21a0/docent/judges/util/forgiving_json.py +108 -0
  21. docent_python-0.1.21a0/docent/judges/util/meta_schema.json +84 -0
  22. docent_python-0.1.21a0/docent/judges/util/meta_schema.py +29 -0
  23. docent_python-0.1.21a0/docent/judges/util/parse_output.py +95 -0
  24. docent_python-0.1.21a0/docent/judges/util/voting.py +84 -0
  25. docent_python-0.1.21a0/docent/py.typed +0 -0
  26. docent_python-0.1.21a0/docent/sdk/__init__.py +0 -0
  27. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/trace.py +1 -1
  28. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/pyproject.toml +11 -5
  29. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/uv.lock +215 -5
  30. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/.gitignore +0 -0
  31. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/LICENSE.md +0 -0
  32. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/README.md +0 -0
  33. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/__init__.py +0 -0
  34. {docent_python-0.1.20a0/docent/sdk → docent_python-0.1.21a0/docent/_llm_util}/__init__.py +0 -0
  35. /docent_python-0.1.20a0/docent/py.typed → /docent_python-0.1.21a0/docent/_llm_util/data_models/__init__.py +0 -0
  36. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/_log_util/__init__.py +0 -0
  37. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/_log_util/logger.py +0 -0
  38. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/__init__.py +0 -0
  39. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/_tiktoken_util.py +0 -0
  40. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/agent_run.py +0 -0
  41. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/chat/__init__.py +0 -0
  42. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/chat/content.py +0 -0
  43. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/chat/message.py +0 -0
  44. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/chat/tool.py +0 -0
  45. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/citation.py +0 -0
  46. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/judge.py +0 -0
  47. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/metadata_util.py +0 -0
  48. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/regex.py +0 -0
  49. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/remove_invalid_citation_ranges.py +0 -0
  50. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/shared_types.py +0 -0
  51. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/data_models/util.py +0 -0
  52. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/loaders/load_inspect.py +0 -0
  53. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/samples/__init__.py +0 -0
  54. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/samples/load.py +0 -0
  55. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/samples/log.eval +0 -0
  56. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/samples/tb_airline.json +0 -0
  57. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/sdk/agent_run_writer.py +0 -0
  58. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/sdk/client.py +0 -0
  59. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/trace_2.py +0 -0
  60. {docent_python-0.1.20a0 → docent_python-0.1.21a0}/docent/trace_temp.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: docent-python
- Version: 0.1.20a0
+ Version: 0.1.21a0
  Summary: Docent SDK
  Project-URL: Homepage, https://github.com/TransluceAI/docent
  Project-URL: Issues, https://github.com/TransluceAI/docent/issues
@@ -9,8 +9,12 @@ Author-email: Transluce <info@transluce.org>
  License-Expression: Apache-2.0
  License-File: LICENSE.md
  Requires-Python: >=3.11
+ Requires-Dist: anthropic>=0.47.0
  Requires-Dist: backoff>=2.2.1
+ Requires-Dist: google-genai>=1.16.1
  Requires-Dist: inspect-ai>=0.3.132
+ Requires-Dist: jsonschema>=4.24.0
+ Requires-Dist: openai>=1.68.0
  Requires-Dist: opentelemetry-api>=1.34.1
  Requires-Dist: opentelemetry-exporter-otlp-proto-grpc>=1.34.1
  Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.34.1
@@ -20,6 +24,7 @@ Requires-Dist: opentelemetry-instrumentation-langchain>=0.40.14
  Requires-Dist: opentelemetry-instrumentation-openai>=0.40.14
  Requires-Dist: opentelemetry-instrumentation-threading>=0.55b1
  Requires-Dist: opentelemetry-sdk>=1.34.1
+ Requires-Dist: orjson>=3.11.3
  Requires-Dist: pydantic>=2.11.7
  Requires-Dist: pyyaml>=6.0.2
  Requires-Dist: tiktoken>=0.7.0
@@ -0,0 +1,48 @@
+ class LLMException(Exception):
+     error_type_id = "other"
+     user_message = "The model failed to respond. Please try again later."
+
+
+ class CompletionTooLongException(LLMException):
+     error_type_id = "completion_too_long"
+     user_message = "Completion too long."
+
+
+ class RateLimitException(LLMException):
+     error_type_id = "rate_limit"
+     user_message = "Rate limited by the model provider. Please wait and try again."
+
+
+ class ContextWindowException(LLMException):
+     error_type_id = "context_window"
+     user_message = "Context window exceeded."
+
+
+ class NoResponseException(LLMException):
+     error_type_id = "no_response"
+     user_message = "The model returned an empty response. Please try again later."
+
+
+ class DocentUsageLimitException(LLMException):
+     error_type_id = "docent_usage_limit"
+     user_message = "Free daily usage limit reached. Add your own API key in settings or contact us for increased limits."
+
+
+ class ValidationFailedException(LLMException):
+     error_type_id = "validation_failed"
+     user_message = "The model returned invalid output that failed validation."
+
+     def __init__(self, message: str = "", failed_output: str | None = None):
+         super().__init__(message)
+         self.failed_output = failed_output
+
+
+ LLM_ERROR_TYPES: list[type[LLMException]] = [
+     LLMException,
+     CompletionTooLongException,
+     RateLimitException,
+     ContextWindowException,
+     NoResponseException,
+     DocentUsageLimitException,
+     ValidationFailedException,
+ ]
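
Note (not part of the diff): a minimal sketch of how these exception types are intended to be used. The stable error_type_id strings let a list of errors be serialized and mapped back to exception classes, which is exactly what LLMOutput.from_dict in llm_output.py below does.

from docent._llm_util.data_models.exceptions import (
    LLM_ERROR_TYPES,
    LLMException,
    RateLimitException,
)

# Map the stable string ids back to exception classes, as LLMOutput.from_dict does.
error_type_map = {e.error_type_id: e for e in LLM_ERROR_TYPES}
assert error_type_map["rate_limit"] is RateLimitException

# Unknown ids fall back to the generic LLMException with its default user_message.
exc = error_type_map.get("unknown_id", LLMException)()
print(exc.error_type_id, "-", exc.user_message)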
@@ -0,0 +1,320 @@
+ import json
+ from dataclasses import dataclass, field
+ from typing import Any, Literal, Protocol, cast
+
+ from openai.types.chat.chat_completion_token_logprob import TopLogprob
+ from pydantic import BaseModel
+
+ from docent._llm_util.data_models.exceptions import (
+     LLM_ERROR_TYPES,
+     CompletionTooLongException,
+     LLMException,
+ )
+ from docent._log_util import get_logger
+ from docent.data_models.chat import ToolCall
+
+ logger = get_logger(__name__)
+
+ FinishReasonType = Literal[
+     "error",
+     "stop",
+     "length",
+     "tool_calls",
+     "content_filter",
+     "function_call",
+     "streaming",
+     "refusal",
+ ]
+ """Possible reasons for an LLM completion to finish."""
+
+
+ TokenType = Literal["input", "output", "cache_read", "cache_write"]
+
+
+ class UsageMetrics:
+     _usage: dict[TokenType, int]
+
+     def __init__(self, **kwargs: int | None):
+         filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
+         self._usage = cast(dict[TokenType, int], filtered_kwargs)
+
+     def __getitem__(self, key: TokenType) -> int:
+         return self._usage.get(key, 0)
+
+     def __setitem__(self, key: TokenType, value: int):
+         self._usage[key] = value
+
+     def to_dict(self) -> dict[TokenType, int]:
+         # Filter out 0 values to avoid cluttering the database
+         return {k: v for k, v in self._usage.items() if v != 0}
+
+     @property
+     def total_tokens(self) -> int:
+         return self["input"] + self["output"]
+
+
+ class LLMCompletion(BaseModel):
+     """A single completion from an LLM.
+
+     Attributes:
+         text: The generated text content.
+         tool_calls: List of tool calls made during the completion.
+         finish_reason: Reason why the completion finished.
+         top_logprobs: Probability distribution for top token choices.
+     """
+
+     text: str | None = None
+     tool_calls: list[ToolCall] | None = None
+     finish_reason: FinishReasonType | None = None
+     top_logprobs: list[list[TopLogprob]] | None = None
+     reasoning_tokens: str | None = None
+
+     @property
+     def no_text(self) -> bool:
+         """Check if the completion has no text.
+
+         Returns:
+             bool: True if text is None or empty, False otherwise.
+         """
+         return self.text is None or len(self.text) == 0
+
+
+ @dataclass
+ class LLMOutput:
+     """Container for LLM output, potentially with multiple completions.
+
+     Aggregates completions from an LLM along with metadata and error information.
+
+     Attributes:
+         model: The name/identifier of the model used.
+         completions: List of individual completions.
+         errors: List of error types encountered during generation.
+     """
+
+     model: str
+     completions: list[LLMCompletion]
+     errors: list[LLMException] = field(default_factory=list)
+     usage: UsageMetrics = field(default_factory=UsageMetrics)
+     from_cache: bool = False
+
+     @property
+     def non_empty(self) -> bool:
+         """Check if there are any completions.
+
+         Returns:
+             bool: True if there's at least one completion, False otherwise.
+         """
+         return len(self.completions) > 0
+
+     @property
+     def first(self) -> LLMCompletion | None:
+         """Get the first completion if available.
+
+         Returns:
+             LLMCompletion | None: The first completion or None if no completions exist.
+         """
+         return self.completions[0] if self.non_empty else None
+
+     @property
+     def first_text(self) -> str | None:
+         """Get the text of the first completion if available.
+
+         Returns:
+             str | None: The text of the first completion or None if no completion exists.
+         """
+         return self.first.text if self.first else None
+
+     @property
+     def did_error(self) -> bool:
+         """Check if any errors occurred during generation.
+
+         Returns:
+             bool: True if there were errors, False otherwise.
+         """
+         return bool(self.errors)
+
+     def to_dict(self) -> dict[str, Any]:
+         return {
+             "model": self.model,
+             "completions": [comp.model_dump() for comp in self.completions],
+             "errors": [e.error_type_id for e in self.errors],
+             "usage": self.usage.to_dict(),
+             "from_cache": self.from_cache,
+         }
+
+     @classmethod
+     def from_dict(cls, data: dict[str, Any]) -> "LLMOutput":
+         error_type_map = {e.error_type_id: e for e in LLM_ERROR_TYPES}
+         errors = data.get("errors", [])
+         errors = [error_type_map.get(e, LLMException)() for e in errors]
+
+         completions = data.get("completions", [])
+         completions = [LLMCompletion.model_validate(comp) for comp in completions]
+
+         usage: dict[TokenType, int] = {}
+         if data_usage := data.get("usage"):
+             usage = cast(dict[TokenType, int], data_usage)
+
+         return cls(
+             model=data["model"],
+             completions=completions,
+             errors=errors,
+             usage=UsageMetrics(**usage),
+             from_cache=bool(data.get("from_cache", False)),
+         )
+
+
+ @dataclass
+ class ToolCallPartial:
+     """Partial representation of a tool call before full processing.
+
+     Used as an intermediate format before finalizing into a complete ToolCall.
+
+     Args:
+         id: The identifier for the tool call.
+         function: The name of the function to call.
+         arguments_raw: Raw JSON string of arguments for the function.
+         type: The type of the tool call, always "function".
+     """
+
+     id: str | None
+     function: str | None
+     arguments_raw: str | None
+     type: Literal["function"]
+
+
+ class LLMCompletionPartial(LLMCompletion):
+     """Partial representation of an LLM completion before finalization.
+
+     Extends LLMCompletion but with tool_calls being a list of ToolCallPartial.
+     This is used during the processing stage before tool calls are fully parsed.
+
+     Attributes:
+         tool_calls: List of partial tool call representations.
+     """
+
+     tool_calls: list[ToolCallPartial | None] | None = None  # type: ignore
+
+
+ class LLMOutputPartial(LLMOutput):
+     """Partial representation of LLM output before finalization.
+
+     Extends LLMOutput but with completions being a list of LLMCompletionPartial.
+     Used as an intermediate format during processing.
+
+     Attributes:
+         completions: List of partial completions.
+     """
+
+     completions: list[LLMCompletionPartial]  # type: ignore
+
+
+ def finalize_llm_output_partial(partial: LLMOutputPartial) -> LLMOutput:
+     """Convert a partial LLM output into a finalized LLM output.
+
+     Processes tool calls by parsing their arguments from raw JSON strings,
+     handles errors in JSON parsing, and provides warnings for truncated completions.
+
+     Args:
+         partial: The partial LLM output to finalize.
+
+     Returns:
+         LLMOutput: The finalized LLM output with processed tool calls.
+
+     Raises:
+         CompletionTooLongException: If the completion was truncated due to length
+             and resulted in empty text.
+         ValueError: If tool call ID or function is missing in the partial data.
+     """
+
+     def _parse_tool_call(tc_partial: ToolCallPartial):
+         if tc_partial.id is None:
+             raise ValueError("Tool call ID not found in partial; check for parsing errors")
+         if tc_partial.function is None:
+             raise ValueError("Tool call function not found in partial; check for parsing errors")
+
+         arguments: dict[str, Any] = {}
+         # Attempt to load arguments into JSON
+         try:
+             arguments = json.loads(tc_partial.arguments_raw or "{}")
+             parse_error = None
+         # If the tool call arguments are not valid JSON, return an empty dict with the error
+         except Exception as e:
+             arguments = {"__parse_error_raw_args": tc_partial.arguments_raw}
+             parse_error = f"Couldn't parse tool call arguments as JSON: {e}. Original input: {tc_partial.arguments_raw}"
+
+         return ToolCall(
+             id=tc_partial.id,
+             function=tc_partial.function,
+             arguments=arguments,
+             parse_error=parse_error,
+             type=tc_partial.type,
+         )
+
+     output = LLMOutput(
+         model=partial.model,
+         completions=[
+             LLMCompletion(
+                 text=c.text,
+                 tool_calls=[_parse_tool_call(tc) for tc in (c.tool_calls or []) if tc is not None],
+                 finish_reason=c.finish_reason,
+                 reasoning_tokens=c.reasoning_tokens,
+             )
+             for c in partial.completions
+         ],
+         usage=partial.usage,
+         from_cache=False,
+     )
+
+     # If the completion is empty and was truncated (likely due to too much reasoning), raise an exception
+     if output.first and output.first.finish_reason == "length" and output.first.no_text:
+         raise CompletionTooLongException(
+             "Completion empty due to truncation. Consider increasing max_new_tokens."
+         )
+     for c in output.completions:
+         if c.finish_reason == "length":
+             logger.warning(
+                 "Completion truncated due to length; consider increasing max_new_tokens."
+             )
+
+     return output
+
+
+ class AsyncLLMOutputStreamingCallback(Protocol):
+     """Protocol for asynchronous streaming callbacks with batch index.
+
+     Defines the expected signature for callbacks that handle streaming output
+     with a batch index.
+
+     Args:
+         batch_index: The index of the current batch.
+         llm_output: The LLM output for the current batch.
+     """
+
+     async def __call__(
+         self,
+         batch_index: int,
+         llm_output: LLMOutput,
+     ) -> None: ...
+
+
+ class AsyncSingleLLMOutputStreamingCallback(Protocol):
+     """Protocol for asynchronous streaming callbacks without batch indexing.
+
+     Defines the expected signature for callbacks that handle streaming output
+     without batch indexing.
+
+     Args:
+         llm_output: The LLM output to process.
+     """
+
+     async def __call__(
+         self,
+         llm_output: LLMOutput,
+     ) -> None: ...
+
+
+ class AsyncEmbeddingStreamingCallback(Protocol):
+     """Protocol for sending progress updates for embedding generation."""
+
+     async def __call__(self, progress: int) -> None: ...
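
Note (not part of the diff): a short sketch of the to_dict/from_dict round trip defined above, which is what llm_cache.py below relies on when storing completions as JSON. The model name and field values are made up for illustration.

from docent._llm_util.data_models.llm_output import LLMCompletion, LLMOutput, UsageMetrics

# Build an output with one completion and some token usage (illustrative values).
output = LLMOutput(
    model="some-model",
    completions=[LLMCompletion(text="hello", finish_reason="stop")],
    usage=UsageMetrics(input=12, output=3),
)

# Serialize to a plain dict (what the cache stores as JSON), then reconstruct it.
restored = LLMOutput.from_dict(output.to_dict())
assert restored.first_text == "hello"
assert restored.usage.total_tokens == 15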
@@ -0,0 +1,79 @@
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ from typing import Literal
+
+ from docent._llm_util.data_models.llm_output import (
+     AsyncLLMOutputStreamingCallback,
+     LLMOutput,
+ )
+ from docent._llm_util.prod_llms import MessagesInput, get_llm_completions_async
+ from docent._llm_util.providers.preference_types import ModelOption
+ from docent.data_models.chat import ToolInfo
+
+ __all__ = ["BaseLLMService"]
+
+
+ class BaseLLMService(ABC):
+     """Common interface for LLM services."""
+
+     @abstractmethod
+     async def get_completions(
+         self,
+         *,
+         inputs: list[MessagesInput],
+         model_options: list[ModelOption],
+         tools: list[ToolInfo] | None = None,
+         tool_choice: Literal["auto", "required"] | None = None,
+         max_new_tokens: int = 1024,
+         temperature: float = 1.0,
+         logprobs: bool = False,
+         top_logprobs: int | None = None,
+         max_concurrency: int = 100,
+         timeout: float = 120.0,
+         streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
+         validation_callback: AsyncLLMOutputStreamingCallback | None = None,
+         completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+         use_cache: bool = False,
+     ) -> list[LLMOutput]:
+         """Request completions from a configured LLM provider."""
+
+
+ class SimpleLLMService(BaseLLMService):
+     """Lightweight LLM service that simply forwards completion requests.
+     Does not support cost tracking, usage limits, global scheduling or rate limiting."""
+
+     async def get_completions(
+         self,
+         *,
+         inputs: list[MessagesInput],
+         model_options: list[ModelOption],
+         tools: list[ToolInfo] | None = None,
+         tool_choice: Literal["auto", "required"] | None = None,
+         max_new_tokens: int = 1024,
+         temperature: float = 1.0,
+         logprobs: bool = False,
+         top_logprobs: int | None = None,
+         max_concurrency: int = 100,
+         timeout: float = 120.0,
+         streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
+         validation_callback: AsyncLLMOutputStreamingCallback | None = None,
+         completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+         use_cache: bool = False,
+     ) -> list[LLMOutput]:
+         return await get_llm_completions_async(
+             inputs=inputs,
+             model_options=model_options,
+             tools=tools,
+             tool_choice=tool_choice,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             logprobs=logprobs,
+             top_logprobs=top_logprobs,
+             max_concurrency=max_concurrency,
+             timeout=timeout,
+             streaming_callback=streaming_callback,
+             validation_callback=validation_callback,
+             completion_callback=completion_callback,
+             use_cache=use_cache,
+         )
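
Note (not part of the diff): SimpleLLMService forwards every argument to get_llm_completions_async, so the main reason to subclass BaseLLMService yourself is to swap in a different backend. A minimal sketch, using the hypothetical class name CannedLLMService as a test double:

from typing import Any

from docent._llm_util.data_models.llm_output import LLMCompletion, LLMOutput
from docent._llm_util.data_models.simple_svc import BaseLLMService


class CannedLLMService(BaseLLMService):
    """Test double that returns a fixed completion per input instead of calling a provider."""

    async def get_completions(self, *, inputs, model_options, **kwargs: Any) -> list[LLMOutput]:
        # One canned LLMOutput per input, mirroring the list-in/list-out contract above.
        return [
            LLMOutput(model="canned", completions=[LLMCompletion(text="ok", finish_reason="stop")])
            for _ in inputs
        ]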
@@ -0,0 +1,193 @@
+ import hashlib
+ import json
+ import os
+ import sqlite3
+ from contextlib import contextmanager
+ from pathlib import Path
+ from typing import Literal
+
+ from docent._llm_util.data_models.llm_output import LLMOutput
+ from docent._log_util import get_logger
+ from docent.data_models.chat import ChatMessage, ToolInfo
+
+ logger = get_logger(__name__)
+
+
+ class LLMCache:
+     def __init__(self, db_path: str | None = None):
+         if db_path is None:
+             llm_cache_path = os.getenv("LLM_CACHE_PATH")
+             if llm_cache_path is None or llm_cache_path == "":
+                 raise ValueError("LLM_CACHE_PATH is not set")
+             else:
+                 cache_dir = Path(llm_cache_path)
+                 cache_dir.mkdir(parents=True, exist_ok=True)
+                 db_path = str(cache_dir / "llm_cache.db")
+
+         self.db_path = db_path
+         self._init_db()
+
+     def _init_db(self) -> None:
+         with self._get_connection() as conn:
+             conn.execute(
+                 """
+                 CREATE TABLE IF NOT EXISTS llm_cache (
+                     key TEXT PRIMARY KEY,
+                     completion TEXT,
+                     model_name TEXT,
+                     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                 )
+                 """
+             )
+
+     @contextmanager
+     def _get_connection(self):
+         conn = sqlite3.connect(self.db_path)
+         try:
+             yield conn
+         finally:
+             conn.close()
+
+     def _create_key(
+         self,
+         messages: list[ChatMessage],
+         model_name: str,
+         *,
+         tools: list[ToolInfo] | None = None,
+         tool_choice: Literal["auto", "required"] | None = None,
+         reasoning_effort: Literal["low", "medium", "high"] | None = None,
+         temperature: float = 1.0,
+         logprobs: bool = False,
+         top_logprobs: int | None = None,
+     ) -> str:
+         """Create a deterministic hash key from messages and model."""
+         # Convert messages to a stable string representation
+         message_str = json.dumps(
+             [msg.model_dump(exclude={"id"}) for msg in messages], sort_keys=True
+         )
+
+         # Convert tools to a stable string representation if present
+         tools_str = (
+             json.dumps([tool.model_dump() for tool in tools], sort_keys=True) if tools else None
+         )
+
+         # Combine all parameters into a single string
+         key_str = (
+             f"{message_str}:{model_name}:{tools_str}:{tool_choice}:{reasoning_effort}:{temperature}"
+         )
+         if logprobs:
+             key_str += f":{top_logprobs}"
+         return hashlib.sha256(key_str.encode()).hexdigest()
+
+     def get(
+         self,
+         messages: list[ChatMessage],
+         model_name: str,
+         *,
+         tools: list[ToolInfo] | None = None,
+         tool_choice: Literal["auto", "required"] | None = None,
+         reasoning_effort: Literal["low", "medium", "high"] | None = None,
+         temperature: float = 1.0,
+         logprobs: bool = False,
+         top_logprobs: int | None = None,
+     ) -> LLMOutput | None:
+         """Get cached completion for a conversation if it exists."""
+
+         key = self._create_key(
+             messages,
+             model_name,
+             tools=tools,
+             tool_choice=tool_choice,
+             reasoning_effort=reasoning_effort,
+             temperature=temperature,
+             logprobs=logprobs,
+             top_logprobs=top_logprobs,
+         )
+
+         with self._get_connection() as conn:
+             cursor = conn.execute("SELECT completion FROM llm_cache WHERE key = ?", (key,))
+             result = cursor.fetchone()
+             if not result:
+                 return None
+             out = LLMOutput.from_dict(json.loads(result[0]))
+             out.from_cache = True
+             return out
+
+     def set(
+         self,
+         messages: list[ChatMessage],
+         model_name: str,
+         llm_output: LLMOutput,
+         *,
+         tools: list[ToolInfo] | None = None,
+         tool_choice: Literal["auto", "required"] | None = None,
+         reasoning_effort: Literal["low", "medium", "high"] | None = None,
+         temperature: float = 1.0,
+         logprobs: bool = False,
+         top_logprobs: int | None = None,
+     ) -> None:
+         """Cache a completion for a conversation."""
+
+         key = self._create_key(
+             messages,
+             model_name,
+             tools=tools,
+             tool_choice=tool_choice,
+             reasoning_effort=reasoning_effort,
+             temperature=temperature,
+             logprobs=logprobs,
+             top_logprobs=top_logprobs,
+         )
+
+         with self._get_connection() as conn:
+             conn.execute(
+                 "INSERT OR REPLACE INTO llm_cache (key, completion, model_name) VALUES (?, ?, ?)",
+                 (key, json.dumps(llm_output.to_dict()), model_name),
+             )
+             conn.commit()
+
+     def set_batch(
+         self,
+         messages_list: list[list[ChatMessage]],
+         model_name: str,
+         llm_output_list: list[LLMOutput],
+         *,
+         tools: list[ToolInfo] | None = None,
+         tool_choice: Literal["auto", "required"] | None = None,
+         reasoning_effort: Literal["low", "medium", "high"] | None = None,
+         temperature: float = 1.0,
+         logprobs: bool = False,
+         top_logprobs: int | None = None,
+     ) -> None:
+         """Cache a completion for a conversation."""
+
+         keys: list[str] = []
+         for messages in messages_list:
+             key = self._create_key(
+                 messages,
+                 model_name,
+                 tools=tools,
+                 tool_choice=tool_choice,
+                 reasoning_effort=reasoning_effort,
+                 temperature=temperature,
+                 logprobs=logprobs,
+                 top_logprobs=top_logprobs,
+             )
+             keys.append(key)
+
+         with self._get_connection() as conn:
+             conn.executemany(
+                 "INSERT OR REPLACE INTO llm_cache (key, completion, model_name) VALUES (?, ?, ?)",
+                 [
+                     (key, json.dumps(llm_output.to_dict()), model_name)
+                     for key, llm_output in zip(keys, llm_output_list)
+                 ],
+             )
+             conn.commit()
+
+     def clear(self) -> None:
+         """Clear all cached completions."""
+
+         with self._get_connection() as conn:
+             conn.execute("DELETE FROM llm_cache")
+             conn.commit()
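
Note (not part of the diff): a minimal sketch of the cache round trip. It passes db_path explicitly instead of setting the LLM_CACHE_PATH environment variable, and it uses an empty message list purely for illustration, since constructing ChatMessage objects is outside this diff.

import tempfile
from pathlib import Path

from docent._llm_util.data_models.llm_output import LLMCompletion, LLMOutput
from docent._llm_util.llm_cache import LLMCache

# Point the cache at a throwaway SQLite file rather than LLM_CACHE_PATH.
cache = LLMCache(db_path=str(Path(tempfile.mkdtemp()) / "llm_cache.db"))

messages = []  # an empty conversation, purely for illustration
output = LLMOutput(model="some-model", completions=[LLMCompletion(text="cached answer")])

cache.set(messages, "some-model", output)
hit = cache.get(messages, "some-model")
assert hit is not None and hit.from_cache and hit.first_text == "cached answer"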