docent-python 0.1.22a0__py3-none-any.whl → 0.1.24a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of docent-python might be problematic.
- docent/_llm_util/data_models/llm_output.py +3 -0
- docent/_llm_util/llm_cache.py +4 -4
- docent/_llm_util/{prod_llms.py → llm_svc.py} +104 -86
- docent/_llm_util/providers/preference_types.py +2 -2
- docent/data_models/__init__.py +2 -2
- docent/data_models/judge.py +7 -4
- docent/judges/__init__.py +2 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +476 -121
- docent/judges/runner.py +66 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +73 -2
- docent/judges/util/meta_schema.json +3 -1
- docent/judges/util/parse_output.py +8 -16
- docent/judges/util/voting.py +38 -13
- docent/sdk/client.py +90 -41
- docent/trace.py +35 -0
- {docent_python-0.1.22a0.dist-info → docent_python-0.1.24a0.dist-info}/METADATA +1 -1
- {docent_python-0.1.22a0.dist-info → docent_python-0.1.24a0.dist-info}/RECORD +21 -20
- docent/_llm_util/data_models/simple_svc.py +0 -79
- docent/trace_2.py +0 -1842
- {docent_python-0.1.22a0.dist-info → docent_python-0.1.24a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.22a0.dist-info → docent_python-0.1.24a0.dist-info}/licenses/LICENSE.md +0 -0
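Note on the service-layer changes in the file list above: prod_llms.py is renamed to llm_svc.py, data_models/simple_svc.py is deleted, and (per the impl.py diff below) the judges no longer default to SimpleLLMService() but require an explicit BaseLLMService. A minimal sketch of the implied import migration follows; the 0.1.22a0 import path is inferred from the deleted module name and is an assumption, while the 0.1.24a0 path is taken directly from the diff.

# Sketch of the import migration implied by this release.
# 0.1.22a0 (inferred old path; module removed in this release):
#   from docent._llm_util.data_models.simple_svc import SimpleLLMService
#   judge = MajorityVotingJudge(cfg, n_rollouts_per_input=5)  # llm_svc defaulted to SimpleLLMService()
# 0.1.24a0 (path confirmed by the impl.py diff below):
from docent._llm_util.llm_svc import BaseLLMService  # judges now take an explicit llm_svc argument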
docent/judges/impl.py
CHANGED
@@ -1,71 +1,275 @@
-import
+import random
+import re
 from abc import ABC, abstractmethod
-from
+from contextlib import nullcontext
+from typing import Any, Sequence
 
+import anyio
+import yaml
+from pydantic_core import to_jsonable_python
+from tqdm.auto import tqdm
+
+from docent._llm_util.data_models.exceptions import ValidationFailedException
 from docent._llm_util.data_models.llm_output import LLMOutput
-from docent._llm_util.
+from docent._llm_util.llm_svc import BaseLLMService
 from docent._log_util import get_logger
 from docent.data_models.agent_run import AgentRun
-from docent.
-
+from docent.data_models.chat.message import (
+    AssistantMessage,
+    ChatMessage,
+    ToolMessage,
+    UserMessage,
+)
+from docent.data_models.chat.tool import ToolInfo
+from docent.judges.types import JudgeResult, JudgeVariant, ResultType, Rubric
+from docent.judges.util.parse_output import parse_and_validate_output_str
 from docent.judges.util.voting import (
-
+    JudgeOutputDistribution,
+    compute_output_distributions,
     find_modal_result,
     get_agreement_keys,
 )
+from docent.trace import agent_run_context, agent_run_metadata
 
 logger = get_logger(__name__)
 
 
 class BaseJudge(ABC):
-    def __init__(
+    def __init__(
+        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
+    ):
         self.cfg = cfg
         self.llm_svc = llm_svc
+        self.docent_collection_id = docent_collection_id
 
     @abstractmethod
-    async def __call__(self, agent_run: AgentRun
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
         """Returns None if all rollouts failed to produce a valid output."""
 
+    @abstractmethod
+    async def estimate_output_distrs(
+        self, agent_run: AgentRun, **kwargs: Any
+    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+        """Estimate the output distribution of each output key."""
 
-
-    """Rolls out the judge multiple times, then uses majority voting to determine the final result."""
-
-    def __init__(
-        self,
-        cfg: Rubric,
-        n_rollouts_per_input: int,
-        llm_svc: BaseLLMService = SimpleLLMService(),
-    ):
-        super().__init__(cfg, llm_svc)
-        self.n_rollouts_per_input = n_rollouts_per_input
-
-    async def __call__(
-        self,
-        agent_run: AgentRun,
-        max_concurrency: int = 10,
-    ) -> JudgeResult | None:
+    def _get_validation_callback(self, agent_run: AgentRun):
         async def _validation_callback(batch_index: int, llm_output: LLMOutput):
-
-
-
+            validated_output = self._validate_first_response_tag_or_entire_output(
+                llm_output.first_text or "", agent_run
+            )
+            if validated_output is None:
+                raise ValidationFailedException(
+                    "Validation failed", failed_output=llm_output.first_text
+                )
+
+        return _validation_callback
+
+    async def one_rollout(
+        self, agent_run: AgentRun
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        with agent_run_context() if self.docent_collection_id is not None else nullcontext():
+            if self.cfg.rollout_type == "single_turn":
+                output, metadata = await self.one_single_turn_rollout(agent_run)
+            elif self.cfg.rollout_type == "multi_turn":
+                output, metadata = await self.one_multi_turn_rollout(
+                    agent_run, max_turns=10, max_steps_per_turn=5
+                )
+            else:
+                raise ValueError(f"Invalid rollout type: {self.cfg.rollout_type}")
+
+            if self.docent_collection_id is not None:
+                agent_run_metadata(
+                    {
+                        "agent_run_id": agent_run.id,
+                        "judge_output": output,
+                        "judge_rollout_metadata": to_jsonable_python(metadata),
+                    }
+                )
+
+        return output, metadata
+
+    def _validate_first_response_tag_or_entire_output(
+        self, output_str: str, agent_run: AgentRun
+    ) -> dict[str, Any] | None:
+        """Validate the first <response> tag in the output string.
+        For backward compatibility, also try to validate the entire output as JSON, for
+        old system prompts that don't ask for <response> tags.
+
+        Args:
+            output_str: The output string to validate
+            agent_run: The agent run to validate against
+
+        Returns:
+            The validated output if successful, None otherwise
+        """
+        response_matches = re.findall(r"<response>(.*?)</response>", output_str, re.DOTALL)
+
+        # Try to validate any match; take the first
+        for response_text in response_matches:
+            try:
+                validated_output = parse_and_validate_output_str(
+                    response_text, self.cfg.output_schema, agent_run
+                )
+                return validated_output
+            except ValidationFailedException:
+                continue  # Try the next match if validation fails
+
+        # Try to validate the entire output as JSON
+        # But only if the output _didn't_ contain a <response>...</response> tag
+        if not response_matches:
+            try:
+                validated_output = parse_and_validate_output_str(
+                    output_str, self.cfg.output_schema, agent_run
+                )
+                return validated_output
+            except ValidationFailedException:
+                pass
+
+        return None
+
+    ########################
+    # Single turn rollouts #
+    ########################
+
+    async def one_single_turn_rollout(
+        self, agent_run: AgentRun
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        prompt = [UserMessage(content=self.cfg.materialize_system_prompt(agent_run))]
         outputs = await self.llm_svc.get_completions(
-            inputs=[prompt
+            inputs=[prompt],
             model_options=[self.cfg.judge_model],
             max_new_tokens=16384,
             timeout=180.0,
             use_cache=False,
-            validation_callback=
-            max_concurrency=max_concurrency,
+            validation_callback=self._get_validation_callback(agent_run),
        )
+        output_str = outputs[0].first_text
+
+        # Extract all <response>...</response> tags from the current message
+        validated_output = self._validate_first_response_tag_or_entire_output(
+            output_str or "", agent_run
+        )
+        if validated_output is not None:
+            return validated_output, {"full_output": output_str}
+        else:
+            return None, None
+
+    #######################
+    # Multi-turn rollouts #
+    #######################
+
+    async def one_multi_turn_rollout(
+        self, agent_run: AgentRun, max_turns: int, max_steps_per_turn: int
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        msgs = [UserMessage(content=self.cfg.materialize_system_prompt(agent_run))]
+        for _ in range(max_turns):
+            msgs = await self.agent_one_turn(msgs, max_steps_per_turn=max_steps_per_turn)
+
+            last_msg_content = msgs[-1].text if msgs else None
+            # Extract all <response>...</response> tags from the current message
+            # Return if we find a valid response; otherwise, continue
+            validated_output = self._validate_first_response_tag_or_entire_output(
+                last_msg_content or "", agent_run
+            )
+            if validated_output is not None:
+                # When returning, strip out the system message, which duplicates the agent run
+                # content many times.
+                return validated_output, {"rollout_messages": msgs[1:]}
+
+        # No <response>...</response> tags with valid JSON,so return None
+        return None, None
+
+    async def agent_one_turn(self, init_msgs: Sequence[ChatMessage], max_steps_per_turn: int):
+        """Given a list of messages, run one turn of the agent.
+        The agent may invoke tools, so we loop until there are no more to handle.
+        """
+
+        msgs = list(init_msgs)  # Shallow copy is fine
+        for _ in range(max_steps_per_turn):
+            last_msg = msgs[-1]
+            if last_msg.role == "system" or last_msg.role == "user" or last_msg.role == "tool":
+                outputs = await self.llm_svc.get_completions(
+                    inputs=[msgs],
+                    model_options=[self.cfg.judge_model],
+                    tools=[
+                        ToolInfo(
+                            name="step_finished",
+                            description="Call this tool to indicate that you have finished one step in the decision procedure",
+                        )
+                    ],
+                    max_new_tokens=16384,
+                    timeout=180.0,
+                    use_cache=False,
+                )
+                output = outputs[0].first
+                if output is None:
+                    # FIXME(mengk): handle empty completion
+                    raise ValueError("Empty completion in agent one turn")
+                new_assistant_msg = AssistantMessage(
+                    content=output.text or "", tool_calls=output.tool_calls
+                )
+                msgs.append(new_assistant_msg)
+            elif last_msg.role == "assistant":
+                if last_msg.tool_calls is not None:
+                    msgs.extend(
+                        [
+                            ToolMessage(
+                                content="Step completed",
+                                tool_call_id=tool_call.id,
+                            )
+                            for tool_call in last_msg.tool_calls
+                        ]
+                    )
+                else:
+                    break  # Terminate if there are no more tool calls to handle
+            else:
+                raise ValueError(f"Unknown message role: {last_msg.role}")
+        return msgs
 
-        # Process each rollout independently
-        indep_results: list[dict[str, Any]] = []
-        for output in outputs:
-            if validated_output := parse_and_validate_llm_output(
-                output, self.cfg.output_schema, agent_run
-            ):
-                indep_results.append(validated_output)
 
+class SingleRolloutJudge(BaseJudge):
+    """Rolls out the judge once."""
+
+    def __init__(self, cfg: Rubric, llm_svc: BaseLLMService):
+        super().__init__(cfg, llm_svc)
+
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+        output, metadata = await self.one_rollout(agent_run)
+        if output is None:
+            return None
+        else:
+            return JudgeResult(
+                agent_run_id=agent_run.id,
+                rubric_id=self.cfg.id,
+                rubric_version=self.cfg.version,
+                output=output,
+                result_metadata={"rollout_metadata": metadata},
+                result_type=ResultType.DIRECT_RESULT,
+            )
+
+
+class MajorityVotingJudge(BaseJudge):
+    """Rolls out the judge multiple times, then uses majority voting to determine the final result."""
+
+    def __init__(
+        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
+    ):
+        super().__init__(cfg, llm_svc, docent_collection_id)
+
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+        indep_results: list[dict[str, Any]] = []
+        indep_rollout_metadata: list[dict[str, Any] | None] = []
+
+        async def _execute():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                indep_results.append(result)
+                indep_rollout_metadata.append(metadata)
+
+        # Run rollouts concurrently
+        async with anyio.create_task_group() as tg:
+            for _ in range(self.cfg.n_rollouts_per_input):
+                tg.start_soon(_execute)
         if not indep_results:
             return None
 
@@ -79,7 +283,7 @@ class MajorityVotingJudge(BaseJudge):
         final_output = indep_results[final_max_idx]
 
         # Compute the distribution of the output across the agreement keys
-
+        final_output_distributions = compute_output_distributions(
             indep_results, self.cfg.output_schema, agreement_keys
         )
 
@@ -94,54 +298,122 @@ class MajorityVotingJudge(BaseJudge):
                 "final_results": indep_results,
                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                 "final_max_idx": final_max_idx,
-                "
+                "final_output_distributions": final_output_distributions,
+                "final_rollout_metadata": indep_rollout_metadata,
             },
             result_type=ResultType.DIRECT_RESULT,
         )
 
+    async def estimate_output_distrs(
+        self, agent_run: AgentRun, *, n_initial_rollouts_to_sample: int, **kwargs: Any
+    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+        if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
+            raise ValueError(
+                "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+            )
+
+        indep_results: list[dict[str, Any]] = []
+        indep_rollout_metadata: list[dict[str, Any] | None] = []
+        pbar = tqdm(total=n_initial_rollouts_to_sample, desc="Independent rollouts", leave=False)
+
+        async def _execute():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                indep_results.append(result)
+                indep_rollout_metadata.append(metadata)
+            pbar.update(1)
+
+        # Run rollouts concurrently
+        async with anyio.create_task_group() as tg:
+            for _ in range(n_initial_rollouts_to_sample):
+                tg.start_soon(_execute)
+
+        pbar.close()
+
+        if not indep_results:
+            return None
+
+        # Compute the probability vector for each agreement key
+        distributions = compute_output_distributions(
+            indep_results, self.cfg.output_schema, get_agreement_keys(self.cfg.output_schema)
+        )
+
+        return distributions, {
+            "first_step_rollouts": indep_results,
+            "first_step_rollout_metadata": indep_rollout_metadata,
+        }
+
 
 class MultiReflectionJudge(BaseJudge):
     """Rolls out the judge multiple times, then uses reflection to determine the final result."""
 
     def __init__(
-        self,
-        cfg: Rubric,
-        n_rollouts_per_input: int,
-        llm_svc: BaseLLMService = SimpleLLMService(),
+        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
     ):
-        super().__init__(cfg, llm_svc)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        super().__init__(cfg, llm_svc, docent_collection_id)
+
+    async def one_rollout_second_stage(
+        self, agent_run: AgentRun, first_stage_results: list[dict[str, Any]]
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        """Reflect on the results of the first stage of rollouts.
+        TODO(mengk): this is only done in a single-turn way. We should generalize this to multi-turn.
+        """
+
+        # Construct *single* reflection prompt
+        first_stage_results_text = "\n\n".join(
+            [
+                f"Rollout {j+1}:\n{yaml.dump(r, width=float('inf'))}"
+                for j, r in enumerate(first_stage_results)
+            ]
+        )
+        reflection_instruction = (
+            f"We have sampled a judge {len(first_stage_results)} times to get {len(first_stage_results)} independent answers to the same rubric evaluation:\n"
+            f"{first_stage_results_text}\n\n"
+            f"Please reflect on these answers. Consider all the information and evidence presented. "
+            f"Return a final answer in the same JSON format as before."
+        )
+        reflection_prompt = [
+            # Original system prompt
+            {"role": "system", "content": self.cfg.materialize_system_prompt(agent_run)},
+            # Additional reflection instruction as a user message (kind of awkward)
+            {"role": "user", "content": reflection_instruction},
+        ]
+
+        # Ask the judge to reflect on the others' results
         outputs = await self.llm_svc.get_completions(
-            inputs=[
-            model_options=[
+            inputs=[reflection_prompt],
+            model_options=[self.cfg.judge_model],
             max_new_tokens=16384,
             timeout=180.0,
             use_cache=False,
-            validation_callback=
-            max_concurrency=max_concurrency,
+            validation_callback=self._get_validation_callback(agent_run),
         )
+        output_str = outputs[0].first_text
 
-
-
-
-
-
-
-
+        validated_output = self._validate_first_response_tag_or_entire_output(
+            output_str or "", agent_run
+        )
+        if validated_output is not None:
+            return validated_output, None
+        else:
+            return None, None
 
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+        rubric = self.cfg
+
+        indep_results: list[dict[str, Any]] = []
+        indep_rollout_metadata: list[dict[str, Any] | None] = []
+
+        async def _execute():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                indep_results.append(result)
+                indep_rollout_metadata.append(metadata)
+
+        # Stage 1: run rollouts concurrently
+        async with anyio.create_task_group() as tg:
+            for _ in range(self.cfg.n_rollouts_per_input):
+                tg.start_soon(_execute)
         if not indep_results:
             return None
 
@@ -151,61 +423,28 @@ class MultiReflectionJudge(BaseJudge):
             indep_results, agreement_keys
         )
 
-
-
-
-
-        other_results = [r for j, r in enumerate(indep_results) if j != cur_index]
-
-        # Create the reflection message
-        other_results_text = "\n\n".join(
-            [f"Answer {j+1}:\n{json.dumps(r, indent=2)}" for j, r in enumerate(other_results)]
-        )
-
-        reflection_instruction = (
-            f"Here are {len(other_results)} other independent answers to the same rubric evaluation:\n\n"
-            f"{other_results_text}\n\n"
-            f"Please reflect on these other answers and your own answer. "
-            f"Consider if any of them have identified important aspects you missed, or if there are disagreements that should be resolved. "
-            f"Then provide your final answer in the same JSON format as before."
-        )
-
-        # Construct the multi-message prompt
-        # 1. Original user message
-        # 2. Assistant message with the rollout's result
-        # 3. New user message asking for reflection
-        return [
-            *prompt,  # Original user message(s)
-            {"role": "assistant", "content": json.dumps(result, indent=2)},
-            {"role": "user", "content": reflection_instruction},
-        ]
-
-        final_results = indep_results.copy()  # Shallow copy
+        # Stage 2: reflect on the results
+        # Shallow copies are fine
+        final_results = indep_results.copy()
+        final_rollout_metadata = indep_rollout_metadata.copy()
         if len(indep_results) > 1:
-
-
-
-
-
-
-
-
-            max_concurrency=max_concurrency,
-            )
+            candidate_final_results: list[dict[str, Any]] = []
+            candidate_final_rollout_metadata: list[dict[str, Any] | None] = []
+
+            async def _execute_second_stage():
+                result, metadata = await self.one_rollout_second_stage(agent_run, indep_results)
+                if result is not None:
+                    candidate_final_results.append(result)
+                    candidate_final_rollout_metadata.append(metadata)
 
-
-
-
-            if output.first_text is None:
-                continue
-            if v_output := parse_and_validate_llm_output(
-                output, rubric.output_schema, agent_run
-            ):
-                reflected_results.append(v_output)
+            async with anyio.create_task_group() as tg:
+                for _ in range(self.cfg.n_rollouts_per_input):
+                    tg.start_soon(_execute_second_stage)
 
             # Use reflected results if we got any, otherwise fall back to original results
-            if
-                final_results =
+            if candidate_final_results:
+                final_results = candidate_final_results
+                final_rollout_metadata = candidate_final_rollout_metadata
             else:
                 logger.warning("No reflected results found, falling back to original results")
 
@@ -223,10 +462,126 @@ class MultiReflectionJudge(BaseJudge):
                 "final_results": final_results,
                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                 "final_max_idx": final_max_idx,
+                "final_rollout_metadata": final_rollout_metadata,
                 # Also include initial measurements
                 "indep_results": indep_results,
                 "indep_max_idx": indep_max_idx,
                 "indep_agt_key_modes_and_counts": indep_agt_key_modes_and_counts,
+                "indep_rollout_metadata": indep_rollout_metadata,
             },
             result_type=ResultType.DIRECT_RESULT,
         )
+
+    async def estimate_output_distrs(
+        self,
+        agent_run: AgentRun,
+        *,
+        n_initial_rollouts_to_sample: int,
+        n_combinations_to_sample: int,
+        n_reflection_rollouts_to_sample: int,
+        **kwargs: Any,
+    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+        if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
+            raise ValueError(
+                "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+            )
+        if self.cfg.n_rollouts_per_input > n_reflection_rollouts_to_sample:
+            raise ValueError(
+                "n_reflection_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+            )
+
+        first_step_rollouts: list[dict[str, Any]] = []
+        first_step_rollout_metadata: list[dict[str, Any] | None] = []
+        first_step_combinations: list[list[dict[str, Any]]] = []
+        second_step_rollouts: list[list[dict[str, Any]]] = []
+        second_step_rollout_metadata: list[list[dict[str, Any] | None]] = []
+
+        ##########
+        # Step 1 #
+        ##########
+
+        pbar_first = tqdm(
+            total=n_initial_rollouts_to_sample, desc="Stage 1: Initial rollouts", leave=False
+        )
+
+        async def _execute_first_stage():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                first_step_rollouts.append(result)
+                first_step_rollout_metadata.append(metadata)
+            pbar_first.update(1)
+
+        # Collect rollouts of the first stage
+        async with anyio.create_task_group() as tg_first:
+            for _ in range(n_initial_rollouts_to_sample):
+                tg_first.start_soon(_execute_first_stage)
+
+        pbar_first.close()
+
+        if len(first_step_rollouts) < self.cfg.n_rollouts_per_input:
+            raise ValueError("Not enough first step rollouts to sample combinations")
+
+        # Sample random k-sized combinations of the first step rollouts
+        for _ in range(n_combinations_to_sample):
+            combination = random.sample(first_step_rollouts, self.cfg.n_rollouts_per_input)
+            first_step_combinations.append(combination)
+            second_step_rollouts.append([])
+            second_step_rollout_metadata.append([])
+
+        ##########
+        # Step 2 #
+        ##########
+
+        pbar_second = tqdm(
+            total=n_combinations_to_sample, desc="Stage 2: Combinations", leave=False
+        )
+
+        async with anyio.create_task_group() as tg_second:
+
+            async def _execute_second_stage(i: int, combination: list[dict[str, Any]]):
+                pbar_third = tqdm(
+                    total=n_reflection_rollouts_to_sample,
+                    desc=f"Stage 2: Combination {i+1}/{n_combinations_to_sample}",
+                    leave=False,
+                )
+
+                async def _execute_second_stage_inner():
+                    result, metadata = await self.one_rollout_second_stage(agent_run, combination)
+                    if result is not None:
+                        second_step_rollouts[i].append(result)
+                        second_step_rollout_metadata[i].append(metadata)
+                    pbar_third.update(1)
+
+                async with anyio.create_task_group() as tg:
+                    for _ in range(n_reflection_rollouts_to_sample):
+                        tg.start_soon(_execute_second_stage_inner)
+
+                pbar_third.close()
+                pbar_second.update(1)
+
+            for i, combination in enumerate(first_step_combinations):
+                tg_second.start_soon(_execute_second_stage, i, combination)
+
+        pbar_second.close()
+
+        output_distributions = compute_output_distributions(
+            [sublist for el in second_step_rollouts for sublist in el],
+            self.cfg.output_schema,
+            get_agreement_keys(self.cfg.output_schema),
+        )
+
+        return output_distributions, {
+            "first_step_rollouts": first_step_rollouts,
+            "first_step_rollout_metadata": first_step_rollout_metadata,
+            "first_step_combinations": first_step_combinations,
+            "second_step_rollouts": second_step_rollouts,
+            "second_step_rollout_metadata": second_step_rollout_metadata,
+        }
+
+
+def build_judge(rubric: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None):
+    if rubric.judge_variant == JudgeVariant.MAJORITY:
+        return MajorityVotingJudge(rubric, llm_svc, docent_collection_id)
+    elif rubric.judge_variant == JudgeVariant.MULTI_REFLECT:
+        return MultiReflectionJudge(rubric, llm_svc, docent_collection_id)
+    raise ValueError(f"Invalid variant: {rubric.judge_variant}")
|