docent-python 0.1.19a0__py3-none-any.whl → 0.1.27a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python might be problematic. Click here for more details.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +331 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/llm_svc.py +472 -0
- docent/_llm_util/model_registry.py +130 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/__init__.py +2 -2
- docent/data_models/agent_run.py +1 -0
- docent/data_models/judge.py +7 -4
- docent/data_models/transcript.py +2 -0
- docent/data_models/util.py +170 -0
- docent/judges/__init__.py +23 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +587 -0
- docent/judges/runner.py +129 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +311 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +86 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +87 -0
- docent/judges/util/voting.py +139 -0
- docent/sdk/client.py +181 -44
- docent/trace.py +362 -44
- {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/METADATA +11 -5
- docent_python-0.1.27a0.dist-info/RECORD +59 -0
- docent_python-0.1.19a0.dist-info/RECORD +0 -32
- {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/licenses/LICENSE.md +0 -0
docent/judges/impl.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import re
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from contextlib import nullcontext
|
|
5
|
+
from typing import Any, Sequence
|
|
6
|
+
|
|
7
|
+
import anyio
|
|
8
|
+
import yaml
|
|
9
|
+
from pydantic_core import to_jsonable_python
|
|
10
|
+
from tqdm.auto import tqdm
|
|
11
|
+
|
|
12
|
+
from docent._llm_util.data_models.exceptions import ValidationFailedException
|
|
13
|
+
from docent._llm_util.data_models.llm_output import LLMOutput
|
|
14
|
+
from docent._llm_util.llm_svc import BaseLLMService
|
|
15
|
+
from docent._log_util import get_logger
|
|
16
|
+
from docent.data_models.agent_run import AgentRun
|
|
17
|
+
from docent.data_models.chat.message import (
|
|
18
|
+
AssistantMessage,
|
|
19
|
+
ChatMessage,
|
|
20
|
+
ToolMessage,
|
|
21
|
+
UserMessage,
|
|
22
|
+
)
|
|
23
|
+
from docent.data_models.chat.tool import ToolInfo
|
|
24
|
+
from docent.judges.types import JudgeResult, JudgeVariant, ResultType, Rubric
|
|
25
|
+
from docent.judges.util.parse_output import parse_and_validate_output_str
|
|
26
|
+
from docent.judges.util.voting import (
|
|
27
|
+
JudgeOutputDistribution,
|
|
28
|
+
compute_output_distributions,
|
|
29
|
+
find_modal_result,
|
|
30
|
+
get_agreement_keys,
|
|
31
|
+
)
|
|
32
|
+
from docent.trace import agent_run_context, agent_run_metadata
|
|
33
|
+
|
|
34
|
+
logger = get_logger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BaseJudge(ABC):
    """Shared rollout and validation machinery for rubric judges.

    A judge evaluates an ``AgentRun`` against ``cfg`` (a ``Rubric``) by asking
    ``llm_svc`` for completions. Subclasses decide how many rollouts to sample
    and how to aggregate them; this class provides the single-turn and
    multi-turn rollout primitives and the output validation they share.
    """

    def __init__(
        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
    ):
        self.cfg = cfg
        self.llm_svc = llm_svc
        # When set, every rollout runs inside ``agent_run_context()`` and its
        # output is recorded through ``agent_run_metadata``.
        self.docent_collection_id = docent_collection_id

    @abstractmethod
    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
        """Returns None if all rollouts failed to produce a valid output."""

    @abstractmethod
    async def estimate_output_distrs(
        self, agent_run: AgentRun, **kwargs: Any
    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
        """Estimate the output distribution of each output key."""

    def _get_validation_callback(self, agent_run: AgentRun):
        """Build the ``validation_callback`` handed to ``get_completions``.

        The returned coroutine raises ``ValidationFailedException`` when the
        completion's first text cannot be validated against the rubric's
        output schema, and returns ``None`` otherwise.
        """

        async def _validation_callback(batch_index: int, llm_output: LLMOutput):
            validated_output = self._validate_first_response_tag_or_entire_output(
                llm_output.first_text or "", agent_run
            )
            if validated_output is None:
                raise ValidationFailedException(
                    "Validation failed", failed_output=llm_output.first_text
                )

        return _validation_callback

    async def one_rollout(
        self, agent_run: AgentRun
    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
        """Run one judge rollout of the type configured in ``cfg.rollout_type``.

        Returns:
            ``(output, metadata)``; both are ``None`` when the rollout did not
            produce a valid output.

        Raises:
            ValueError: If ``cfg.rollout_type`` is neither ``"single_turn"``
                nor ``"multi_turn"``.
        """
        with agent_run_context() if self.docent_collection_id is not None else nullcontext():
            if self.cfg.rollout_type == "single_turn":
                output, metadata = await self.one_single_turn_rollout(agent_run)
            elif self.cfg.rollout_type == "multi_turn":
                output, metadata = await self.one_multi_turn_rollout(
                    agent_run, max_turns=10, max_steps_per_turn=5
                )
            else:
                raise ValueError(f"Invalid rollout type: {self.cfg.rollout_type}")

            # Record the rollout only when a collection id was configured.
            if self.docent_collection_id is not None:
                agent_run_metadata(
                    {
                        "agent_run_id": agent_run.id,
                        "judge_output": output,
                        "judge_rollout_metadata": to_jsonable_python(metadata),
                    }
                )

            return output, metadata

    def _validate_first_response_tag_or_entire_output(
        self, output_str: str, agent_run: AgentRun
    ) -> dict[str, Any] | None:
        """Validate the first *valid* <response> tag in the output string.

        Every ``<response>...</response>`` match is tried in order and the
        first one that validates is returned. For backward compatibility, when
        the output contains no <response> tag at all, the entire output is
        validated as JSON instead (for old system prompts that don't ask for
        <response> tags).

        Args:
            output_str: The output string to validate
            agent_run: The agent run to validate against

        Returns:
            The validated output if successful, None otherwise
        """
        response_matches = re.findall(r"<response>(.*?)</response>", output_str, re.DOTALL)

        # Tagged candidates when present, otherwise the whole output. The
        # whole-output fallback is deliberately skipped when tags exist but
        # none of them validates.
        candidates = response_matches if response_matches else [output_str]
        for candidate in candidates:
            try:
                return parse_and_validate_output_str(
                    candidate, self.cfg.output_schema, agent_run
                )
            except ValidationFailedException:
                continue  # Try the next candidate if validation fails

        return None

    ########################
    # Single turn rollouts #
    ########################

    async def one_single_turn_rollout(
        self, agent_run: AgentRun
    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
        """Ask the judge model once and validate its answer.

        Returns:
            ``(validated_output, {"full_output": raw_text})`` on success,
            ``(None, None)`` when the answer does not validate.
        """
        prompt = [UserMessage(content=self.cfg.materialize_system_prompt(agent_run))]
        outputs = await self.llm_svc.get_completions(
            inputs=[prompt],
            model_options=[self.cfg.judge_model],
            max_new_tokens=16384,
            timeout=180.0,
            use_cache=False,
            validation_callback=self._get_validation_callback(agent_run),
        )
        output_str = outputs[0].first_text

        validated_output = self._validate_first_response_tag_or_entire_output(
            output_str or "", agent_run
        )
        if validated_output is None:
            return None, None
        return validated_output, {"full_output": output_str}

    #######################
    # Multi-turn rollouts #
    #######################

    async def one_multi_turn_rollout(
        self, agent_run: AgentRun, max_turns: int, max_steps_per_turn: int
    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
        """Run up to ``max_turns`` agent turns until a valid answer appears.

        After each turn the last message is checked for a valid response.

        Returns:
            ``(validated_output, {"rollout_messages": [...]})`` on success,
            ``(None, None)`` when no turn produced a valid answer.
        """
        msgs: list[ChatMessage] = [
            UserMessage(content=self.cfg.materialize_system_prompt(agent_run))
        ]
        for _ in range(max_turns):
            n_msgs_before = len(msgs)
            msgs = await self.agent_one_turn(msgs, max_steps_per_turn=max_steps_per_turn)

            last_msg_content = msgs[-1].text if msgs else None
            # Return if the last message holds a valid response; otherwise, continue
            validated_output = self._validate_first_response_tag_or_entire_output(
                last_msg_content or "", agent_run
            )
            if validated_output is not None:
                # When returning, strip out the system message, which duplicates the agent run
                # content many times.
                return validated_output, {"rollout_messages": msgs[1:]}

            # A turn that appended nothing made no LLM call, so every later
            # turn would see the same messages and fail the same validation.
            if len(msgs) == n_msgs_before:
                break

        # No <response>...</response> tags with valid JSON, so return None
        return None, None

    async def agent_one_turn(self, init_msgs: Sequence[ChatMessage], max_steps_per_turn: int):
        """Given a list of messages, run one turn of the agent.

        The agent may invoke tools, so we loop until there are no more to handle
        or ``max_steps_per_turn`` steps have been taken.

        Returns:
            A new list holding ``init_msgs`` plus any messages appended during
            the turn.

        Raises:
            ValueError: On an empty completion or an unknown message role.
        """

        msgs = list(init_msgs)  # Shallow copy is fine
        for _ in range(max_steps_per_turn):
            last_msg = msgs[-1]
            if last_msg.role in ("system", "user", "tool"):
                outputs = await self.llm_svc.get_completions(
                    inputs=[msgs],
                    model_options=[self.cfg.judge_model],
                    tools=[
                        ToolInfo(
                            name="step_finished",
                            description="Call this tool to indicate that you have finished one step in the decision procedure",
                        )
                    ],
                    max_new_tokens=16384,
                    timeout=180.0,
                    use_cache=False,
                )
                output = outputs[0].first
                if output is None:
                    # FIXME(mengk): handle empty completion
                    raise ValueError("Empty completion in agent one turn")
                msgs.append(
                    AssistantMessage(content=output.text or "", tool_calls=output.tool_calls)
                )
            elif last_msg.role == "assistant":
                # Truthiness check: an empty tool-call list means there is
                # nothing to answer, exactly like ``None``.
                if not last_msg.tool_calls:
                    break  # Terminate if there are no more tool calls to handle
                # Acknowledge every tool call so the model can continue.
                msgs.extend(
                    ToolMessage(content="Step completed", tool_call_id=tool_call.id)
                    for tool_call in last_msg.tool_calls
                )
            else:
                raise ValueError(f"Unknown message role: {last_msg.role}")
        return msgs
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class SingleRolloutJudge(BaseJudge):
    """Rolls out the judge once."""

    def __init__(
        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
    ):
        # ``docent_collection_id`` is optional so existing two-argument callers
        # keep working; it is forwarded to match the other judge variants.
        super().__init__(cfg, llm_svc, docent_collection_id)

    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
        """Run a single rollout and wrap its output in a ``JudgeResult``.

        Returns:
            The judge result, or ``None`` when the rollout produced no valid
            output.
        """
        output, metadata = await self.one_rollout(agent_run)
        if output is None:
            return None
        return JudgeResult(
            agent_run_id=agent_run.id,
            rubric_id=self.cfg.id,
            rubric_version=self.cfg.version,
            output=output,
            result_metadata={"rollout_metadata": metadata},
            result_type=ResultType.DIRECT_RESULT,
        )

    async def estimate_output_distrs(
        self, agent_run: AgentRun, **kwargs: Any
    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
        """Not supported for the single-rollout judge.

        The override exists because ``BaseJudge`` declares this method
        abstract; without it the class cannot be instantiated at all.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "SingleRolloutJudge does not support estimating output distributions"
        )
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class MajorityVotingJudge(BaseJudge):
    """Rolls out the judge multiple times, then uses majority voting to determine the final result."""

    def __init__(
        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
    ):
        super().__init__(cfg, llm_svc, docent_collection_id)

    async def _sample_rollouts(
        self, agent_run: AgentRun, n_rollouts: int, pbar: Any | None = None
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any] | None]]:
        """Run ``n_rollouts`` rollouts concurrently and keep the valid ones.

        Args:
            agent_run: The agent run to judge.
            n_rollouts: Number of rollouts to start.
            pbar: Optional progress bar, advanced by one per finished rollout.

        Returns:
            Two parallel lists: the valid outputs and their rollout metadata.
        """
        results: list[dict[str, Any]] = []
        rollout_metadata: list[dict[str, Any] | None] = []

        async def _execute():
            result, metadata = await self.one_rollout(agent_run)
            # Failed rollouts are dropped so both lists stay index-aligned.
            if result is not None:
                results.append(result)
                rollout_metadata.append(metadata)
            if pbar is not None:
                pbar.update(1)

        async with anyio.create_task_group() as tg:
            for _ in range(n_rollouts):
                tg.start_soon(_execute)

        return results, rollout_metadata

    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
        """Sample ``cfg.n_rollouts_per_input`` rollouts and return the modal one.

        Returns:
            A ``JudgeResult`` whose output is the rollout chosen by
            ``find_modal_result``, or ``None`` if every rollout failed.
        """
        indep_results, indep_rollout_metadata = await self._sample_rollouts(
            agent_run, self.cfg.n_rollouts_per_input
        )
        if not indep_results:
            return None

        # Get a list of the keys that we want to measure agreement on
        agreement_keys = get_agreement_keys(self.cfg.output_schema)

        # Find the result that best matches modal values
        final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
            indep_results, agreement_keys
        )
        final_output = indep_results[final_max_idx]

        # Compute the distribution of the output across the agreement keys
        final_output_distributions = compute_output_distributions(
            indep_results, self.cfg.output_schema, agreement_keys
        )

        return JudgeResult(
            agent_run_id=agent_run.id,
            rubric_id=self.cfg.id,
            rubric_version=self.cfg.version,
            output=final_output,
            result_metadata={
                "agt_keys": agreement_keys,
                # Final measurements
                "final_results": indep_results,
                "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                "final_max_idx": final_max_idx,
                "final_output_distributions": final_output_distributions,
                "final_rollout_metadata": indep_rollout_metadata,
            },
            result_type=ResultType.DIRECT_RESULT,
        )

    async def estimate_output_distrs(
        self, agent_run: AgentRun, *, n_initial_rollouts_to_sample: int, **kwargs: Any
    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
        """Estimate per-key output distributions from independent rollouts.

        Args:
            agent_run: The agent run to judge.
            n_initial_rollouts_to_sample: Number of rollouts to sample; must be
                at least ``cfg.n_rollouts_per_input``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``(distributions, metadata)``, or ``None`` if every rollout failed.

        Raises:
            ValueError: If ``n_initial_rollouts_to_sample`` is too small.
        """
        if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
            raise ValueError(
                "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
            )

        pbar = tqdm(total=n_initial_rollouts_to_sample, desc="Independent rollouts", leave=False)
        try:
            indep_results, indep_rollout_metadata = await self._sample_rollouts(
                agent_run, n_initial_rollouts_to_sample, pbar
            )
        finally:
            # Close the bar even when a rollout raises.
            pbar.close()

        if not indep_results:
            return None

        # Compute the probability vector for each agreement key
        distributions = compute_output_distributions(
            indep_results, self.cfg.output_schema, get_agreement_keys(self.cfg.output_schema)
        )

        return distributions, {
            "first_step_rollouts": indep_results,
            "first_step_rollout_metadata": indep_rollout_metadata,
        }
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class MultiReflectionJudge(BaseJudge):
    """Rolls out the judge multiple times, then uses reflection to determine the final result."""

    def __init__(
        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
    ):
        super().__init__(cfg, llm_svc, docent_collection_id)

    async def one_rollout_second_stage(
        self, agent_run: AgentRun, first_stage_results: list[dict[str, Any]]
    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
        """Reflect on the results of the first stage of rollouts.
        TODO(mengk): this is only done in a single-turn way. We should generalize this to multi-turn.

        Args:
            agent_run: The agent run being judged.
            first_stage_results: Validated first-stage outputs to show the judge.

        Returns:
            ``(validated_output, None)``; the output is ``None`` when the
            reflection answer does not validate. No metadata is produced.
        """

        # Construct *single* reflection prompt
        first_stage_results_text = "\n\n".join(
            f"Rollout {j+1}:\n{yaml.dump(r, width=float('inf'))}"
            for j, r in enumerate(first_stage_results)
        )
        reflection_instruction = (
            f"We have sampled a judge {len(first_stage_results)} times to get {len(first_stage_results)} independent answers to the same rubric evaluation:\n"
            f"{first_stage_results_text}\n\n"
            f"Please reflect on these answers. Consider all the information and evidence presented. "
            f"Return a final answer in the same JSON format as before."
        )
        reflection_prompt = [
            # Original system prompt
            {"role": "system", "content": self.cfg.materialize_system_prompt(agent_run)},
            # Additional reflection instruction as a user message (kind of awkward)
            {"role": "user", "content": reflection_instruction},
        ]

        # Ask the judge to reflect on the others' results
        outputs = await self.llm_svc.get_completions(
            inputs=[reflection_prompt],
            model_options=[self.cfg.judge_model],
            max_new_tokens=16384,
            timeout=180.0,
            use_cache=False,
            validation_callback=self._get_validation_callback(agent_run),
        )
        output_str = outputs[0].first_text

        validated_output = self._validate_first_response_tag_or_entire_output(
            output_str or "", agent_run
        )
        return validated_output, None

    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
        """Run independent rollouts, then reflection rollouts, and pick the mode.

        Stage 2 runs only when more than one first-stage rollout succeeded. If
        no reflection rollout validates, the first-stage results are used.

        Returns:
            The judge result, or ``None`` if every first-stage rollout failed.
        """
        rubric = self.cfg

        indep_results: list[dict[str, Any]] = []
        indep_rollout_metadata: list[dict[str, Any] | None] = []

        async def _execute():
            result, metadata = await self.one_rollout(agent_run)
            if result is not None:
                indep_results.append(result)
                indep_rollout_metadata.append(metadata)

        # Stage 1: run rollouts concurrently
        async with anyio.create_task_group() as tg:
            for _ in range(self.cfg.n_rollouts_per_input):
                tg.start_soon(_execute)
        if not indep_results:
            return None

        # Compute initial modes
        agreement_keys = get_agreement_keys(rubric.output_schema)
        indep_max_idx, indep_agt_key_modes_and_counts = find_modal_result(
            indep_results, agreement_keys
        )

        # Stage 2: reflect on the results
        # Shallow copies are fine
        final_results = indep_results.copy()
        final_rollout_metadata = indep_rollout_metadata.copy()
        if len(indep_results) > 1:
            candidate_final_results: list[dict[str, Any]] = []
            candidate_final_rollout_metadata: list[dict[str, Any] | None] = []

            async def _execute_second_stage():
                result, metadata = await self.one_rollout_second_stage(agent_run, indep_results)
                if result is not None:
                    candidate_final_results.append(result)
                    candidate_final_rollout_metadata.append(metadata)

            async with anyio.create_task_group() as tg:
                for _ in range(self.cfg.n_rollouts_per_input):
                    tg.start_soon(_execute_second_stage)

            # Use reflected results if we got any, otherwise fall back to original results
            if candidate_final_results:
                final_results = candidate_final_results
                final_rollout_metadata = candidate_final_rollout_metadata
            else:
                logger.warning("No reflected results found, falling back to original results")

        final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
            final_results, agreement_keys
        )
        return JudgeResult(
            agent_run_id=agent_run.id,
            rubric_id=rubric.id,
            rubric_version=rubric.version,
            output=final_results[final_max_idx],
            result_metadata={
                "agt_keys": agreement_keys,
                # Final measurements
                "final_results": final_results,
                "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                "final_max_idx": final_max_idx,
                "final_rollout_metadata": final_rollout_metadata,
                # Also include initial measurements
                "indep_results": indep_results,
                "indep_max_idx": indep_max_idx,
                "indep_agt_key_modes_and_counts": indep_agt_key_modes_and_counts,
                "indep_rollout_metadata": indep_rollout_metadata,
            },
            result_type=ResultType.DIRECT_RESULT,
        )

    async def estimate_output_distrs(
        self,
        agent_run: AgentRun,
        *,
        n_initial_rollouts_to_sample: int,
        n_combinations_to_sample: int,
        n_reflection_rollouts_to_sample: int,
        **kwargs: Any,
    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
        """Estimate per-key output distributions of the two-stage judge.

        Step 1 samples independent rollouts. Random combinations of
        ``cfg.n_rollouts_per_input`` of them are then each reflected on
        ``n_reflection_rollouts_to_sample`` times (step 2); the distributions
        are computed over all valid reflection outputs.

        Args:
            agent_run: The agent run to judge.
            n_initial_rollouts_to_sample: Number of first-stage rollouts.
            n_combinations_to_sample: Number of first-stage combinations to draw.
            n_reflection_rollouts_to_sample: Reflection rollouts per combination.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``(distributions, metadata)``, or ``None`` when no reflection
            rollout produced a valid output.

        Raises:
            ValueError: If a sample size is below ``cfg.n_rollouts_per_input``
                or too few first-stage rollouts succeeded.
        """
        if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
            raise ValueError(
                "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
            )
        if self.cfg.n_rollouts_per_input > n_reflection_rollouts_to_sample:
            raise ValueError(
                "n_reflection_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
            )

        first_step_rollouts: list[dict[str, Any]] = []
        first_step_rollout_metadata: list[dict[str, Any] | None] = []
        first_step_combinations: list[list[dict[str, Any]]] = []
        second_step_rollouts: list[list[dict[str, Any]]] = []
        second_step_rollout_metadata: list[list[dict[str, Any] | None]] = []

        ##########
        # Step 1 #
        ##########

        pbar_first = tqdm(
            total=n_initial_rollouts_to_sample, desc="Stage 1: Initial rollouts", leave=False
        )

        async def _execute_first_stage():
            result, metadata = await self.one_rollout(agent_run)
            if result is not None:
                first_step_rollouts.append(result)
                first_step_rollout_metadata.append(metadata)
            pbar_first.update(1)

        # Collect rollouts of the first stage
        try:
            async with anyio.create_task_group() as tg_first:
                for _ in range(n_initial_rollouts_to_sample):
                    tg_first.start_soon(_execute_first_stage)
        finally:
            # Close the bar even when a rollout raises.
            pbar_first.close()

        if len(first_step_rollouts) < self.cfg.n_rollouts_per_input:
            raise ValueError("Not enough first step rollouts to sample combinations")

        # Sample random k-sized combinations of the first step rollouts
        for _ in range(n_combinations_to_sample):
            combination = random.sample(first_step_rollouts, self.cfg.n_rollouts_per_input)
            first_step_combinations.append(combination)
            # One result slot per combination, filled by index in step 2.
            second_step_rollouts.append([])
            second_step_rollout_metadata.append([])

        ##########
        # Step 2 #
        ##########

        pbar_second = tqdm(
            total=n_combinations_to_sample, desc="Stage 2: Combinations", leave=False
        )

        async def _execute_second_stage(i: int, combination: list[dict[str, Any]]):
            pbar_third = tqdm(
                total=n_reflection_rollouts_to_sample,
                desc=f"Stage 2: Combination {i+1}/{n_combinations_to_sample}",
                leave=False,
            )

            async def _execute_second_stage_inner():
                result, metadata = await self.one_rollout_second_stage(agent_run, combination)
                if result is not None:
                    second_step_rollouts[i].append(result)
                    second_step_rollout_metadata[i].append(metadata)
                pbar_third.update(1)

            try:
                async with anyio.create_task_group() as tg:
                    for _ in range(n_reflection_rollouts_to_sample):
                        tg.start_soon(_execute_second_stage_inner)
            finally:
                pbar_third.close()
            pbar_second.update(1)

        try:
            async with anyio.create_task_group() as tg_second:
                for i, combination in enumerate(first_step_combinations):
                    tg_second.start_soon(_execute_second_stage, i, combination)
        finally:
            pbar_second.close()

        # Flatten the per-combination lists into one list of reflection outputs.
        all_second_step_rollouts = [
            rollout for rollouts in second_step_rollouts for rollout in rollouts
        ]
        if not all_second_step_rollouts:
            # Nothing to build a distribution from; report failure the same
            # way the declared return type allows instead of passing an empty
            # sample on.
            logger.warning("No valid reflection rollouts; cannot estimate output distributions")
            return None

        output_distributions = compute_output_distributions(
            all_second_step_rollouts,
            self.cfg.output_schema,
            get_agreement_keys(self.cfg.output_schema),
        )

        return output_distributions, {
            "first_step_rollouts": first_step_rollouts,
            "first_step_rollout_metadata": first_step_rollout_metadata,
            "first_step_combinations": first_step_combinations,
            "second_step_rollouts": second_step_rollouts,
            "second_step_rollout_metadata": second_step_rollout_metadata,
        }
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def build_judge(rubric: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None):
    """Construct the judge implementation selected by ``rubric.judge_variant``.

    Args:
        rubric: Rubric whose ``judge_variant`` picks the judge class.
        llm_svc: LLM service handed to the judge.
        docent_collection_id: Optional collection id forwarded to the judge.

    Returns:
        A ``MajorityVotingJudge`` or a ``MultiReflectionJudge``.

    Raises:
        ValueError: If the variant is neither MAJORITY nor MULTI_REFLECT.
    """
    variant = rubric.judge_variant
    if variant == JudgeVariant.MAJORITY:
        judge_cls = MajorityVotingJudge
    elif variant == JudgeVariant.MULTI_REFLECT:
        judge_cls = MultiReflectionJudge
    else:
        raise ValueError(f"Invalid variant: {rubric.judge_variant}")
    return judge_cls(rubric, llm_svc, docent_collection_id)
|