docent-python 0.1.22a0__tar.gz → 0.1.24a0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python might be problematic.
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/.gitignore +0 -1
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/PKG-INFO +1 -1
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/data_models/llm_output.py +3 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/llm_cache.py +4 -4
- docent_python-0.1.22a0/docent/_llm_util/prod_llms.py → docent_python-0.1.24a0/docent/_llm_util/llm_svc.py +104 -86
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/providers/preference_types.py +2 -2
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/__init__.py +2 -2
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/judge.py +7 -4
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/judges/__init__.py +2 -0
- docent_python-0.1.24a0/docent/judges/analysis.py +77 -0
- docent_python-0.1.24a0/docent/judges/impl.py +587 -0
- docent_python-0.1.24a0/docent/judges/runner.py +66 -0
- docent_python-0.1.24a0/docent/judges/stats.py +205 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/judges/types.py +73 -2
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/judges/util/meta_schema.json +3 -1
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/judges/util/parse_output.py +8 -16
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/judges/util/voting.py +38 -13
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/sdk/client.py +90 -41
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/trace.py +35 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/pyproject.toml +1 -1
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/uv.lock +18 -1
- docent_python-0.1.22a0/docent/_llm_util/data_models/simple_svc.py +0 -79
- docent_python-0.1.22a0/docent/judges/impl.py +0 -232
- docent_python-0.1.22a0/docent/trace_2.py +0 -1842
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/LICENSE.md +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/README.md +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/__init__.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/__init__.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/data_models/__init__.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/data_models/exceptions.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/model_registry.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/providers/__init__.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/providers/anthropic.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/providers/common.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/providers/google.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/providers/openai.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/providers/openrouter.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/providers/provider_registry.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_log_util/__init__.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_log_util/logger.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/_tiktoken_util.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/agent_run.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/chat/__init__.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/chat/content.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/chat/message.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/chat/tool.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/citation.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/metadata_util.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/regex.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/remove_invalid_citation_ranges.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/shared_types.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/transcript.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/util.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/judges/util/forgiving_json.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/judges/util/meta_schema.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/loaders/load_inspect.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/py.typed +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/samples/__init__.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/samples/load.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/samples/log.eval +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/samples/tb_airline.json +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/sdk/__init__.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/sdk/agent_run_writer.py +0 -0
- {docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/trace_temp.py +0 -0
{docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/data_models/llm_output.py
RENAMED
@@ -96,6 +96,7 @@ class LLMOutput:
     errors: list[LLMException] = field(default_factory=list)
     usage: UsageMetrics = field(default_factory=UsageMetrics)
     from_cache: bool = False
+    duration: float | None = None

     @property
     def non_empty(self) -> bool:
@@ -140,6 +141,7 @@ class LLMOutput:
             "errors": [e.error_type_id for e in self.errors],
             "usage": self.usage.to_dict(),
             "from_cache": self.from_cache,
+            "duration": self.duration,
         }

     @classmethod
@@ -161,6 +163,7 @@ class LLMOutput:
             errors=errors,
             usage=UsageMetrics(**usage),
             from_cache=bool(data.get("from_cache", False)),
+            duration=data.get("duration"),
         )

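The new duration field records wall-clock seconds for uncached calls and is carried through serialization under the "duration" key. A minimal sketch of reading it; the import path follows the file layout above, and the helper function is illustrative rather than part of the package:

from docent._llm_util.data_models.llm_output import LLMOutput


def log_call_duration(output: LLMOutput) -> None:
    # duration stays None for cached results and for data written by older versions
    if output.duration is not None:
        print(f"LLM call took {output.duration:.2f}s")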
{docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/llm_cache.py
RENAMED

@@ -55,7 +55,7 @@ class LLMCache:
         *,
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        reasoning_effort: Literal["low", "medium", "high"] | None = None,
+        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
@@ -86,7 +86,7 @@ class LLMCache:
         *,
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        reasoning_effort: Literal["low", "medium", "high"] | None = None,
+        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
@@ -121,7 +121,7 @@ class LLMCache:
         *,
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        reasoning_effort: Literal["low", "medium", "high"] | None = None,
+        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
@@ -154,7 +154,7 @@ class LLMCache:
         *,
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        reasoning_effort: Literal["low", "medium", "high"] | None = None,
+        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
docent_python-0.1.22a0/docent/_llm_util/prod_llms.py → docent_python-0.1.24a0/docent/_llm_util/llm_svc.py
RENAMED

@@ -1,17 +1,8 @@
-"""
-At some point we'll want to do a refactor to support different types of provider/key swapping
-due to different scenarios. However, this'll probably be a breaking change, which is why I'm
-not doing it now.
-
-- mengk
-"""
-
+import time
 import traceback
-from contextlib import nullcontext
 from functools import partial
 from typing import (
     Any,
-    AsyncContextManager,
     Literal,
     Protocol,
     Sequence,
@@ -20,6 +11,7 @@ from typing import (
 )

 import anyio
+from anyio import Lock, Semaphore
 from anyio.abc import TaskGroup
 from tqdm.auto import tqdm

@@ -44,10 +36,12 @@ from docent._llm_util.providers.provider_registry import (
 from docent._log_util import get_logger
 from docent.data_models.chat import ChatMessage, ToolInfo, parse_chat_message

-MAX_VALIDATION_ATTEMPTS = 3
-
 logger = get_logger(__name__)

+MAX_VALIDATION_ATTEMPTS = 3
+DEFAULT_MAX_CONCURRENCY = 100
+DEFAULT_SVC_MAX_CONCURRENCY = 100
+

 @runtime_checkable
 class MessageResolver(Protocol):
@@ -87,11 +81,11 @@ async def _parallelize_calls(
     tool_choice: Literal["auto", "required"] | None,
     max_new_tokens: int,
     temperature: float,
-    reasoning_effort: Literal["low", "medium", "high"] | None,
+    reasoning_effort: Literal["minimal", "low", "medium", "high"] | None,
     logprobs: bool,
     top_logprobs: int | None,
     timeout: float,
-    semaphore:
+    semaphore: Semaphore,
     # use_tqdm: bool,
     cache: LLMCache | None = None,
 ):
@@ -122,17 +116,19 @@ async def _parallelize_calls(
     # Save resolved messages to avoid multiple resolutions
     resolved_messages: list[list[ChatMessage] | None] = [None] * len(inputs)

-
+    # Not sure why the cast is necessary for the type checker
+    cancelled_due_to_usage_limit: bool = cast(bool, False)

     async def _limited_task(i: int, cur_input: MessagesInput, tg: TaskGroup):
         nonlocal responses, pbar, resolved_messages, cancelled_due_to_usage_limit

-        async with semaphore
+        async with semaphore:
             messages = _resolve_messages_input(cur_input)
             resolved_messages[i] = messages

             retry_count = 0
             result = None
+            call_started_at: float | None = None

             # Check if there's a cached result
             cached_result = (
@@ -154,6 +150,7 @@ async def _parallelize_calls(
                 if streaming_callback is not None:
                     await streaming_callback(i, result)
             else:
+                call_started_at = time.perf_counter()
                 while retry_count < MAX_VALIDATION_ATTEMPTS:
                     try:
                         if streaming_callback is None:
@@ -187,7 +184,7 @@ async def _parallelize_calls(
                             errors=[e],
                         )
                        break
-                    except DocentUsageLimitException as
+                    except DocentUsageLimitException as _:
                        result = LLMOutput(
                            model=model_name,
                            completions=[],
@@ -219,6 +216,10 @@ async def _parallelize_calls(
                        )
                        break

+            # Only store the elapsed time if we didn't hit the cache and the call was successful
+            if cached_result is None and result is not None and call_started_at is not None:
+                result.duration = time.perf_counter() - call_started_at
+
             # Always call completion callback with final result (success or error)
             if completion_callback and result is not None:
                 try:
@@ -273,21 +274,24 @@ async def _parallelize_calls(
     # Cache what we have so far if something got cancelled
     except anyio.get_cancelled_exc_class():
         num_cached = _cache_responses()
-
-
-
-
+        if num_cached:
+            logger.info(
+                f"Cancelled {len(inputs) - num_cached} unfinished LLM API calls, but cached {num_cached} completed responses"
+            )

-
-
-
-
-
-
-
-
-
-
+        # If the task was cancelled due to usage limit, set the response to a usage limit exception
+        if cancelled_due_to_usage_limit:
+            for i, response in enumerate(responses):
+                if response is None:
+                    responses[i] = LLMOutput(
+                        model=model_name,
+                        completions=[],
+                        errors=[DocentUsageLimitException()],
+                    )
+                else:
+                    response.errors.append(DocentUsageLimitException())
+
+        raise

     # Cache results if available
     _cache_responses()
@@ -300,51 +304,88 @@ async def _parallelize_calls(
     return cast(list[LLMOutput], responses)


-class LLMManager:
-    def __init__(
-        self
-
-
-
-    ):
-
-
-
-
-
-
+class BaseLLMService:
+    def __init__(self, max_concurrency: int = DEFAULT_SVC_MAX_CONCURRENCY):
+        self._semaphore = Semaphore(max_concurrency)
+        self._client_cache: dict[tuple[str, str | None], Any] = {}  # (provider, api_key) -> client
+        self._client_cache_lock = Lock()
+
+    async def _get_cached_client(self, provider: str, override_key: str | None) -> Any:
+        """Return a cached client for the provider/api-key tuple, creating one if needed."""
+        cache_key = (provider, override_key)
+        async with self._client_cache_lock:
+            cached = self._client_cache.get(cache_key)
+            if cached is not None:
+                return cached

-
-
-
+            client_factory = PROVIDERS[provider]["async_client_getter"]
+            new_client = client_factory(override_key)
+            self._client_cache[cache_key] = new_client
+            return new_client

     async def get_completions(
         self,
+        *,
         inputs: list[MessagesInput],
+        model_options: list[ModelOption],
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        max_new_tokens: int =
+        max_new_tokens: int = 1024,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
-
-        timeout: float = 5.0,
+        timeout: float = 120.0,
         streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
         validation_callback: AsyncLLMOutputStreamingCallback | None = None,
         completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+        use_cache: bool = False,
+        _api_key_overrides: dict[str, str] = dict(),
     ) -> list[LLMOutput]:
+        """Request completions from a configured LLM provider."""
+
+        # We don't support logprobs for Anthropic yet
+        if logprobs:
+            for model_option in model_options:
+                if model_option.provider == "anthropic":
+                    raise ValueError(
+                        f"Logprobs are not supported for Anthropic, so we can't use model {model_option.model_name}"
+                    )
+
+        # Instantiate cache
+        # TODO(mengk): make this more robust, possibly move to a NoSQL database or something
+        try:
+            cache = LLMCache() if use_cache else None
+        except ValueError as e:
+            logger.warning(f"Disabling LLM cache due to init error: {e}")
+            cache = None
+
+        # Initialize pointer to which model we're using; used for model rotation after failures
+        current_model_option_index = 0
+
+        def _rotate_model_option() -> ModelOption | None:
+            nonlocal current_model_option_index
+
+            current_model_option_index += 1
+            if current_model_option_index >= len(model_options):
+                logger.error("All model options are exhausted")
+                return None
+
+            new_model_option = model_options[current_model_option_index]
+            logger.warning(f"Switched to next model {new_model_option.model_name}")
+            return new_model_option
+
         while True:
             # Parse the current model option
-            cur_option =
+            cur_option = model_options[current_model_option_index]
             provider, model_name, reasoning_effort = (
                 cur_option.provider,
                 cur_option.model_name,
                 cur_option.reasoning_effort,
             )

-            override_key =
+            override_key = _api_key_overrides.get(provider)

-            client =
+            client = await self._get_cached_client(provider, override_key)
             single_output_getter = PROVIDERS[provider]["single_output_getter"]
             single_streaming_output_getter = PROVIDERS[provider]["single_streaming_output_getter"]

@@ -369,10 +410,8 @@ class LLMManager:
                 logprobs=logprobs,
                 top_logprobs=top_logprobs,
                 timeout=timeout,
-                semaphore=
-
-                ),
-                cache=self.cache,
+                semaphore=self._semaphore,
+                cache=cache,
             )
             assert len(outputs) == len(inputs), "Number of outputs must match number of messages"

@@ -388,23 +427,13 @@ class LLMManager:
            )
            if num_rotation_errors > 0:
                logger.warning(f"{model_name}: {num_rotation_errors} API errors")
-                if not
+                if not _rotate_model_option():
                    break
            else:
                break

        return outputs

-    def _rotate_model_option(self) -> ModelOption | None:
-        self.current_model_option_index += 1
-        if self.current_model_option_index >= len(self.model_options):
-            logger.error("All model options are exhausted")
-            return None
-
-        new_model_option = self.model_options[self.current_model_option_index]
-        logger.warning(f"Switched to next model {new_model_option.model_name}")
-        return new_model_option
-

 async def get_llm_completions_async(
     inputs: list[MessagesInput],
@@ -415,40 +444,29 @@ async def get_llm_completions_async(
     temperature: float = 1.0,
     logprobs: bool = False,
     top_logprobs: int | None = None,
-    max_concurrency: int = 100,
     timeout: float = 120.0,
     streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
     validation_callback: AsyncLLMOutputStreamingCallback | None = None,
     completion_callback: AsyncLLMOutputStreamingCallback | None = None,
     use_cache: bool = False,
-
+    _api_key_overrides: dict[str, str] = dict(),
 ) -> list[LLMOutput]:
-
-    if logprobs:
-        for model_option in model_options:
-            if model_option.provider == "anthropic":
-                raise ValueError(
-                    f"Logprobs are not supported for Anthropic, so we can't use model {model_option.model_name}"
-                )
+    """Convenience method for backward compatibility"""

-
-
+    svc = BaseLLMService()
+    return await svc.get_completions(
+        inputs=inputs,
         model_options=model_options,
-        api_key_overrides=api_key_overrides,
-        use_cache=use_cache,
-    )
-
-    return await llm_manager.get_completions(
-        inputs,
         tools=tools,
         tool_choice=tool_choice,
         max_new_tokens=max_new_tokens,
         temperature=temperature,
         logprobs=logprobs,
         top_logprobs=top_logprobs,
-        max_concurrency=max_concurrency,
         timeout=timeout,
         streaming_callback=streaming_callback,
         validation_callback=validation_callback,
         completion_callback=completion_callback,
+        use_cache=use_cache,
+        _api_key_overrides=_api_key_overrides,
     )
{docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/_llm_util/providers/preference_types.py
RENAMED
@@ -22,7 +22,7 @@ class ModelOption(BaseModel):

     provider: str
     model_name: str
-    reasoning_effort: Literal["low", "medium", "high"] | None = None
+    reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None


 class ModelOptionWithContext(BaseModel):
@@ -39,7 +39,7 @@ class ModelOptionWithContext(BaseModel):

     provider: str
     model_name: str
-    reasoning_effort: Literal["low", "medium", "high"] | None = None
+    reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None
     context_window: int
     uses_byok: bool

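The reasoning_effort literal now accepts "minimal" in ModelOption, matching the same literal widened elsewhere in this release. A minimal sketch, with a placeholder model name:

from docent._llm_util.providers.preference_types import ModelOption

option = ModelOption(
    provider="openai",           # provider key, as registered in the provider registry
    model_name="gpt-5-mini",     # placeholder model name for illustration
    reasoning_effort="minimal",  # newly allowed value in 0.1.24a0
)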
{docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/__init__.py
RENAMED

@@ -1,13 +1,13 @@
 from docent.data_models.agent_run import AgentRun
 from docent.data_models.citation import Citation
-from docent.data_models.judge import
+from docent.data_models.judge import Label
 from docent.data_models.regex import RegexSnippet
 from docent.data_models.transcript import Transcript, TranscriptGroup

 __all__ = [
     "AgentRun",
     "Citation",
-    "
+    "Label",
     "RegexSnippet",
     "Transcript",
     "TranscriptGroup",
{docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/data_models/judge.py
RENAMED

@@ -6,11 +6,14 @@ from uuid import uuid4
 from pydantic import BaseModel, Field


-class
+class Label(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))
+
+    label_set_id: str
+
+    label_value: dict[str, Any]
+
     agent_run_id: str
-    rubric_id: str
-    label: dict[str, Any]


-__all__ = ["
+__all__ = ["Label"]
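The Label model replaces the old rubric_id/label fields with label_set_id/label_value and is now exported from docent.data_models. A minimal construction sketch with placeholder identifiers:

from docent.data_models import Label

label = Label(
    label_set_id="label-set-123",  # placeholder id
    agent_run_id="agent-run-456",  # placeholder id
    label_value={"score": 1},      # arbitrary JSON-compatible payload
)
print(label.id)  # auto-generated UUID string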
{docent_python-0.1.22a0 → docent_python-0.1.24a0}/docent/judges/__init__.py
RENAMED

@@ -3,6 +3,7 @@ from docent.judges.types import (
     JudgeResult,
     JudgeResultCompletionCallback,
     JudgeResultWithCitations,
+    JudgeVariant,
     ResultType,
     Rubric,
 )
@@ -18,4 +19,5 @@ __all__ = [
     "JudgeResultWithCitations",
     "JudgeResultCompletionCallback",
     "ResultType",
+    "JudgeVariant",
 ]
docent_python-0.1.24a0/docent/judges/analysis.py
ADDED

@@ -0,0 +1,77 @@
+import json
+from pathlib import Path
+from typing import Any
+
+import anyio
+from pydantic import BaseModel
+from pydantic_core import to_jsonable_python
+from tqdm.auto import tqdm
+
+from docent._log_util import get_logger
+from docent.data_models.agent_run import AgentRun
+from docent.judges.impl import BaseJudge
+from docent.judges.util.voting import JudgeOutputDistribution
+
+logger = get_logger(__name__)
+
+
+class MultiReflectRollouts(BaseModel):
+    """Object is associated with a single agent run"""
+
+    agent_run_id: str
+
+    first_step_rollouts: list[dict[str, Any]]
+    first_step_rollout_metadata: list[dict[str, Any] | None]
+    # Each index in second_step_rollouts corresponds to an index in first_step_combinations
+    # Step 2 rollouts are computed by passing each step 1 combo into the judge several times
+    first_step_combinations: list[list[dict[str, Any]]] | None = None
+    second_step_rollouts: list[list[dict[str, Any]]] | None = None
+    second_step_rollout_metadata: list[list[dict[str, Any] | None]] | None = None
+
+    distributions: dict[str, JudgeOutputDistribution]
+
+
+async def collect_judge_pvs(
+    judge: BaseJudge,
+    agent_runs: list[AgentRun],
+    *,
+    results_path: Path,
+    estimate_output_distrs_kwargs: dict[str, Any],
+):
+    if results_path.exists():
+        raise FileExistsError(f"Results path already exists: {results_path}")
+    results_path.parent.mkdir(parents=True, exist_ok=True)
+
+    results = dict[str, MultiReflectRollouts]()
+    persist_lock = anyio.Lock()
+    pbar = tqdm(total=len(agent_runs), desc="Processing agent runs")
+
+    async def _persist():
+        async with persist_lock:
+            with open(str(results_path), "w") as f:
+                json.dump(to_jsonable_python(results), f, indent=2)
+
+    async def _execute_for_agent_run(agent_run: AgentRun):
+        result = await judge.estimate_output_distrs(agent_run, **estimate_output_distrs_kwargs)
+        if result is None:
+            pbar.update(1)
+            return
+
+        distrs, metadata = result
+        results[agent_run.id] = MultiReflectRollouts.model_validate(
+            {
+                "agent_run_id": agent_run.id,
+                "distributions": distrs,
+                **metadata,
+            }
+        )
+        await _persist()
+        pbar.update(1)
+
+    async with anyio.create_task_group() as tg_outer:
+        for agent_run in agent_runs:
+            tg_outer.start_soon(_execute_for_agent_run, agent_run)
+
+    pbar.close()
+
+    return results
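collect_judge_pvs fans judge.estimate_output_distrs out over the agent runs with an anyio task group and persists partial results to disk after each run. A hedged sketch of calling it; constructing a concrete judge is out of scope here, so the judge and agent runs are assumed to be built elsewhere:

from pathlib import Path

from docent.data_models.agent_run import AgentRun
from docent.judges.analysis import collect_judge_pvs
from docent.judges.impl import BaseJudge


async def run_analysis(judge: BaseJudge, agent_runs: list[AgentRun]) -> None:
    results = await collect_judge_pvs(
        judge,
        agent_runs,
        results_path=Path("out/rollouts.json"),  # must not already exist
        estimate_output_distrs_kwargs={},        # forwarded to judge.estimate_output_distrs
    )
    print(f"Collected output distributions for {len(results)} agent runs")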