docent-python 0.1.21a0__py3-none-any.whl → 0.1.23a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of docent-python might be problematic.

docent/judges/impl.py CHANGED
@@ -1,67 +1,275 @@
-import json
+import random
+import re
 from abc import ABC, abstractmethod
-from typing import Any
+from contextlib import nullcontext
+from typing import Any, Sequence

+import anyio
+import yaml
+from pydantic_core import to_jsonable_python
+from tqdm.auto import tqdm
+
+from docent._llm_util.data_models.exceptions import ValidationFailedException
 from docent._llm_util.data_models.llm_output import LLMOutput
-from docent._llm_util.data_models.simple_svc import BaseLLMService, SimpleLLMService
+from docent._llm_util.llm_svc import BaseLLMService
 from docent._log_util import get_logger
 from docent.data_models.agent_run import AgentRun
-from docent.judges.types import JudgeResult, ResultType, Rubric
-from docent.judges.util.parse_output import parse_and_validate_llm_output
-from docent.judges.util.voting import find_modal_result, get_agreement_keys
+from docent.data_models.chat.message import (
+    AssistantMessage,
+    ChatMessage,
+    ToolMessage,
+    UserMessage,
+)
+from docent.data_models.chat.tool import ToolInfo
+from docent.judges.types import JudgeResult, JudgeVariant, ResultType, Rubric
+from docent.judges.util.parse_output import parse_and_validate_output_str
+from docent.judges.util.voting import (
+    JudgeOutputDistribution,
+    compute_output_distributions,
+    find_modal_result,
+    get_agreement_keys,
+)
+from docent.trace import agent_run_context, agent_run_metadata

 logger = get_logger(__name__)


 class BaseJudge(ABC):
-    def __init__(self, cfg: Rubric, llm_svc: BaseLLMService):
+    def __init__(
+        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
+    ):
         self.cfg = cfg
         self.llm_svc = llm_svc
+        self.docent_collection_id = docent_collection_id

     @abstractmethod
-    async def __call__(self, agent_run: AgentRun, *args: Any, **kwargs: Any) -> JudgeResult | None:
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
         """Returns None if all rollouts failed to produce a valid output."""

+    @abstractmethod
+    async def estimate_output_distrs(
+        self, agent_run: AgentRun, **kwargs: Any
+    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+        """Estimate the output distribution of each output key."""

-class MajorityVotingJudge(BaseJudge):
-    """Rolls out the judge multiple times, then uses majority voting to determine the final result."""
-
-    def __init__(
-        self,
-        cfg: Rubric,
-        n_rollouts_per_input: int,
-        llm_svc: BaseLLMService = SimpleLLMService(),
-    ):
-        super().__init__(cfg, llm_svc)
-        self.n_rollouts_per_input = n_rollouts_per_input
-
-    async def __call__(
-        self,
-        agent_run: AgentRun,
-        max_concurrency: int = 10,
-    ) -> JudgeResult | None:
+    def _get_validation_callback(self, agent_run: AgentRun):
         async def _validation_callback(batch_index: int, llm_output: LLMOutput):
-            parse_and_validate_llm_output(llm_output, self.cfg.output_schema, agent_run)
-
-        prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
+            validated_output = self._validate_first_response_tag_or_entire_output(
+                llm_output.first_text or "", agent_run
+            )
+            if validated_output is None:
+                raise ValidationFailedException(
+                    "Validation failed", failed_output=llm_output.first_text
+                )
+
+        return _validation_callback
+
+    async def one_rollout(
+        self, agent_run: AgentRun
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        with agent_run_context() if self.docent_collection_id is not None else nullcontext():
+            if self.cfg.rollout_type == "single_turn":
+                output, metadata = await self.one_single_turn_rollout(agent_run)
+            elif self.cfg.rollout_type == "multi_turn":
+                output, metadata = await self.one_multi_turn_rollout(
+                    agent_run, max_turns=10, max_steps_per_turn=5
+                )
+            else:
+                raise ValueError(f"Invalid rollout type: {self.cfg.rollout_type}")
+
+            if self.docent_collection_id is not None:
+                agent_run_metadata(
+                    {
+                        "agent_run_id": agent_run.id,
+                        "judge_output": output,
+                        "judge_rollout_metadata": to_jsonable_python(metadata),
+                    }
+                )
+
+        return output, metadata
+
+    def _validate_first_response_tag_or_entire_output(
+        self, output_str: str, agent_run: AgentRun
+    ) -> dict[str, Any] | None:
+        """Validate the first <response> tag in the output string.
+        For backward compatibility, also try to validate the entire output as JSON, for
+        old system prompts that don't ask for <response> tags.
+
+        Args:
+            output_str: The output string to validate
+            agent_run: The agent run to validate against
+
+        Returns:
+            The validated output if successful, None otherwise
+        """
+        response_matches = re.findall(r"<response>(.*?)</response>", output_str, re.DOTALL)
+
+        # Try to validate any match; take the first
+        for response_text in response_matches:
+            try:
+                validated_output = parse_and_validate_output_str(
+                    response_text, self.cfg.output_schema, agent_run
+                )
+                return validated_output
+            except ValidationFailedException:
+                continue  # Try the next match if validation fails
+
+        # Try to validate the entire output as JSON
+        # But only if the output _didn't_ contain a <response>...</response> tag
+        if not response_matches:
+            try:
+                validated_output = parse_and_validate_output_str(
+                    output_str, self.cfg.output_schema, agent_run
+                )
+                return validated_output
+            except ValidationFailedException:
+                pass
+
+        return None
+
+    ########################
+    # Single turn rollouts #
+    ########################
+
+    async def one_single_turn_rollout(
+        self, agent_run: AgentRun
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        prompt = [UserMessage(content=self.cfg.materialize_system_prompt(agent_run))]
         outputs = await self.llm_svc.get_completions(
-            inputs=[prompt for _ in range(self.n_rollouts_per_input)],
+            inputs=[prompt],
             model_options=[self.cfg.judge_model],
             max_new_tokens=16384,
             timeout=180.0,
             use_cache=False,
-            validation_callback=_validation_callback,
-            max_concurrency=max_concurrency,
+            validation_callback=self._get_validation_callback(agent_run),
         )
+        output_str = outputs[0].first_text

-        # Process each rollout independently
-        indep_results: list[dict[str, Any]] = []
-        for output in outputs:
-            if validated_output := parse_and_validate_llm_output(
-                output, self.cfg.output_schema, agent_run
-            ):
-                indep_results.append(validated_output)
+        # Extract all <response>...</response> tags from the current message
+        validated_output = self._validate_first_response_tag_or_entire_output(
+            output_str or "", agent_run
+        )
+        if validated_output is not None:
+            return validated_output, {"full_output": output_str}
+        else:
+            return None, None
+
+    #######################
+    # Multi-turn rollouts #
+    #######################
+
+    async def one_multi_turn_rollout(
+        self, agent_run: AgentRun, max_turns: int, max_steps_per_turn: int
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        msgs = [UserMessage(content=self.cfg.materialize_system_prompt(agent_run))]
+        for _ in range(max_turns):
+            msgs = await self.agent_one_turn(msgs, max_steps_per_turn=max_steps_per_turn)
+
+            last_msg_content = msgs[-1].text if msgs else None
+            # Extract all <response>...</response> tags from the current message
+            # Return if we find a valid response; otherwise, continue
+            validated_output = self._validate_first_response_tag_or_entire_output(
+                last_msg_content or "", agent_run
+            )
+            if validated_output is not None:
+                # When returning, strip out the system message, which duplicates the agent run
+                # content many times.
+                return validated_output, {"rollout_messages": msgs[1:]}
+
+        # No <response>...</response> tags with valid JSON,so return None
+        return None, None
+
+    async def agent_one_turn(self, init_msgs: Sequence[ChatMessage], max_steps_per_turn: int):
+        """Given a list of messages, run one turn of the agent.
+        The agent may invoke tools, so we loop until there are no more to handle.
+        """
+
+        msgs = list(init_msgs)  # Shallow copy is fine
+        for _ in range(max_steps_per_turn):
+            last_msg = msgs[-1]
+            if last_msg.role == "system" or last_msg.role == "user" or last_msg.role == "tool":
+                outputs = await self.llm_svc.get_completions(
+                    inputs=[msgs],
+                    model_options=[self.cfg.judge_model],
+                    tools=[
+                        ToolInfo(
+                            name="step_finished",
+                            description="Call this tool to indicate that you have finished one step in the decision procedure",
+                        )
+                    ],
+                    max_new_tokens=16384,
+                    timeout=180.0,
+                    use_cache=False,
+                )
+                output = outputs[0].first
+                if output is None:
+                    # FIXME(mengk): handle empty completion
+                    raise ValueError("Empty completion in agent one turn")
+                new_assistant_msg = AssistantMessage(
+                    content=output.text or "", tool_calls=output.tool_calls
+                )
+                msgs.append(new_assistant_msg)
+            elif last_msg.role == "assistant":
+                if last_msg.tool_calls is not None:
+                    msgs.extend(
+                        [
+                            ToolMessage(
+                                content="Step completed",
+                                tool_call_id=tool_call.id,
+                            )
+                            for tool_call in last_msg.tool_calls
+                        ]
+                    )
+                else:
+                    break  # Terminate if there are no more tool calls to handle
+            else:
+                raise ValueError(f"Unknown message role: {last_msg.role}")
+        return msgs
+
+
+class SingleRolloutJudge(BaseJudge):
+    """Rolls out the judge once."""
+
+    def __init__(self, cfg: Rubric, llm_svc: BaseLLMService):
+        super().__init__(cfg, llm_svc)

+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+        output, metadata = await self.one_rollout(agent_run)
+        if output is None:
+            return None
+        else:
+            return JudgeResult(
+                agent_run_id=agent_run.id,
+                rubric_id=self.cfg.id,
+                rubric_version=self.cfg.version,
+                output=output,
+                result_metadata={"rollout_metadata": metadata},
+                result_type=ResultType.DIRECT_RESULT,
+            )
+
+
+class MajorityVotingJudge(BaseJudge):
+    """Rolls out the judge multiple times, then uses majority voting to determine the final result."""
+
+    def __init__(
+        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
+    ):
+        super().__init__(cfg, llm_svc, docent_collection_id)
+
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+        indep_results: list[dict[str, Any]] = []
+        indep_rollout_metadata: list[dict[str, Any] | None] = []
+
+        async def _execute():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                indep_results.append(result)
+                indep_rollout_metadata.append(metadata)
+
+        # Run rollouts concurrently
+        async with anyio.create_task_group() as tg:
+            for _ in range(self.cfg.n_rollouts_per_input):
+                tg.start_soon(_execute)
         if not indep_results:
             return None

@@ -74,6 +282,11 @@ class MajorityVotingJudge(BaseJudge):
         )
         final_output = indep_results[final_max_idx]

+        # Compute the distribution of the output across the agreement keys
+        final_output_distributions = compute_output_distributions(
+            indep_results, self.cfg.output_schema, agreement_keys
+        )
+
         return JudgeResult(
             agent_run_id=agent_run.id,
             rubric_id=self.cfg.id,
@@ -85,53 +298,122 @@ class MajorityVotingJudge(BaseJudge):
                 "final_results": indep_results,
                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                 "final_max_idx": final_max_idx,
+                "final_output_distributions": final_output_distributions,
+                "final_rollout_metadata": indep_rollout_metadata,
             },
             result_type=ResultType.DIRECT_RESULT,
         )

+    async def estimate_output_distrs(
+        self, agent_run: AgentRun, *, n_initial_rollouts_to_sample: int, **kwargs: Any
+    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+        if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
+            raise ValueError(
+                "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+            )
+
+        indep_results: list[dict[str, Any]] = []
+        indep_rollout_metadata: list[dict[str, Any] | None] = []
+        pbar = tqdm(total=n_initial_rollouts_to_sample, desc="Independent rollouts", leave=False)
+
+        async def _execute():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                indep_results.append(result)
+                indep_rollout_metadata.append(metadata)
+            pbar.update(1)
+
+        # Run rollouts concurrently
+        async with anyio.create_task_group() as tg:
+            for _ in range(n_initial_rollouts_to_sample):
+                tg.start_soon(_execute)
+
+        pbar.close()
+
+        if not indep_results:
+            return None
+
+        # Compute the probability vector for each agreement key
+        distributions = compute_output_distributions(
+            indep_results, self.cfg.output_schema, get_agreement_keys(self.cfg.output_schema)
+        )
+
+        return distributions, {
+            "first_step_rollouts": indep_results,
+            "first_step_rollout_metadata": indep_rollout_metadata,
+        }
+

 class MultiReflectionJudge(BaseJudge):
     """Rolls out the judge multiple times, then uses reflection to determine the final result."""

     def __init__(
-        self,
-        cfg: Rubric,
-        n_rollouts_per_input: int,
-        llm_svc: BaseLLMService = SimpleLLMService(),
+        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
     ):
-        super().__init__(cfg, llm_svc)
-        self.n_rollouts_per_input = n_rollouts_per_input
-
-    async def __call__(
-        self,
-        agent_run: AgentRun,
-        max_concurrency: int = 10,
-    ) -> JudgeResult | None:
-        rubric = self.cfg
-
-        async def _validation_callback(batch_index: int, llm_output: LLMOutput):
-            parse_and_validate_llm_output(llm_output, rubric.output_schema, agent_run)
-
-        # Run several independent rollouts
-        prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
+        super().__init__(cfg, llm_svc, docent_collection_id)
+
+    async def one_rollout_second_stage(
+        self, agent_run: AgentRun, first_stage_results: list[dict[str, Any]]
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        """Reflect on the results of the first stage of rollouts.
+        TODO(mengk): this is only done in a single-turn way. We should generalize this to multi-turn.
+        """
+
+        # Construct *single* reflection prompt
+        first_stage_results_text = "\n\n".join(
+            [
+                f"Rollout {j+1}:\n{yaml.dump(r, width=float('inf'))}"
+                for j, r in enumerate(first_stage_results)
+            ]
+        )
+        reflection_instruction = (
+            f"We have sampled a judge {len(first_stage_results)} times to get {len(first_stage_results)} independent answers to the same rubric evaluation:\n"
+            f"{first_stage_results_text}\n\n"
+            f"Please reflect on these answers. Consider all the information and evidence presented. "
+            f"Return a final answer in the same JSON format as before."
+        )
+        reflection_prompt = [
+            # Original system prompt
+            {"role": "system", "content": self.cfg.materialize_system_prompt(agent_run)},
+            # Additional reflection instruction as a user message (kind of awkward)
+            {"role": "user", "content": reflection_instruction},
+        ]
+
+        # Ask the judge to reflect on the others' results
         outputs = await self.llm_svc.get_completions(
-            inputs=[prompt for _ in range(self.n_rollouts_per_input)],
-            model_options=[rubric.judge_model],
+            inputs=[reflection_prompt],
+            model_options=[self.cfg.judge_model],
             max_new_tokens=16384,
             timeout=180.0,
             use_cache=False,
-            validation_callback=_validation_callback,
-            max_concurrency=max_concurrency,
+            validation_callback=self._get_validation_callback(agent_run),
         )
+        output_str = outputs[0].first_text

-        # Process each rollout
-        indep_results: list[dict[str, Any]] = []
-        for output in outputs:
-            if output.first_text is None:
-                continue
-            if v_output := parse_and_validate_llm_output(output, rubric.output_schema, agent_run):
-                indep_results.append(v_output)
+        validated_output = self._validate_first_response_tag_or_entire_output(
+            output_str or "", agent_run
+        )
+        if validated_output is not None:
+            return validated_output, None
+        else:
+            return None, None

+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+        rubric = self.cfg
+
+        indep_results: list[dict[str, Any]] = []
+        indep_rollout_metadata: list[dict[str, Any] | None] = []
+
+        async def _execute():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                indep_results.append(result)
+                indep_rollout_metadata.append(metadata)
+
+        # Stage 1: run rollouts concurrently
+        async with anyio.create_task_group() as tg:
+            for _ in range(self.cfg.n_rollouts_per_input):
+                tg.start_soon(_execute)
         if not indep_results:
             return None

@@ -141,61 +423,28 @@ class MultiReflectionJudge(BaseJudge):
             indep_results, agreement_keys
         )

-        def _get_reflection_prompt(cur_index: int):
-            # Current result
-            result = indep_results[cur_index]
-            # Get other results (excluding the current one)
-            other_results = [r for j, r in enumerate(indep_results) if j != cur_index]
-
-            # Create the reflection message
-            other_results_text = "\n\n".join(
-                [f"Answer {j+1}:\n{json.dumps(r, indent=2)}" for j, r in enumerate(other_results)]
-            )
-
-            reflection_instruction = (
-                f"Here are {len(other_results)} other independent answers to the same rubric evaluation:\n\n"
-                f"{other_results_text}\n\n"
-                f"Please reflect on these other answers and your own answer. "
-                f"Consider if any of them have identified important aspects you missed, or if there are disagreements that should be resolved. "
-                f"Then provide your final answer in the same JSON format as before."
-            )
-
-            # Construct the multi-message prompt
-            # 1. Original user message
-            # 2. Assistant message with the rollout's result
-            # 3. New user message asking for reflection
-            return [
-                *prompt,  # Original user message(s)
-                {"role": "assistant", "content": json.dumps(result, indent=2)},
-                {"role": "user", "content": reflection_instruction},
-            ]
-
-        final_results = indep_results.copy()  # Shallow copy
+        # Stage 2: reflect on the results
+        # Shallow copies are fine
+        final_results = indep_results.copy()
+        final_rollout_metadata = indep_rollout_metadata.copy()
         if len(indep_results) > 1:
-            # Ask the judge to reflect on the others' results
-            reflection_outputs = await self.llm_svc.get_completions(
-                inputs=[_get_reflection_prompt(i) for i in range(len(indep_results))],
-                model_options=[rubric.judge_model],
-                max_new_tokens=16384,
-                timeout=180.0,
-                use_cache=False,
-                validation_callback=_validation_callback,
-                max_concurrency=max_concurrency,
-            )
+            candidate_final_results: list[dict[str, Any]] = []
+            candidate_final_rollout_metadata: list[dict[str, Any] | None] = []
+
+            async def _execute_second_stage():
+                result, metadata = await self.one_rollout_second_stage(agent_run, indep_results)
+                if result is not None:
+                    candidate_final_results.append(result)
+                    candidate_final_rollout_metadata.append(metadata)

-            # Process reflection outputs in the same way as the initial rollouts
-            reflected_results: list[dict[str, Any]] = []
-            for output in reflection_outputs:
-                if output.first_text is None:
-                    continue
-                if v_output := parse_and_validate_llm_output(
-                    output, rubric.output_schema, agent_run
-                ):
-                    reflected_results.append(v_output)
+            async with anyio.create_task_group() as tg:
+                for _ in range(self.cfg.n_rollouts_per_input):
+                    tg.start_soon(_execute_second_stage)

             # Use reflected results if we got any, otherwise fall back to original results
-            if reflected_results:
-                final_results = reflected_results
+            if candidate_final_results:
+                final_results = candidate_final_results
+                final_rollout_metadata = candidate_final_rollout_metadata
             else:
                 logger.warning("No reflected results found, falling back to original results")

@@ -213,10 +462,126 @@ class MultiReflectionJudge(BaseJudge):
                 "final_results": final_results,
                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                 "final_max_idx": final_max_idx,
+                "final_rollout_metadata": final_rollout_metadata,
                 # Also include initial measurements
                 "indep_results": indep_results,
                 "indep_max_idx": indep_max_idx,
                 "indep_agt_key_modes_and_counts": indep_agt_key_modes_and_counts,
+                "indep_rollout_metadata": indep_rollout_metadata,
             },
             result_type=ResultType.DIRECT_RESULT,
         )
+
+    async def estimate_output_distrs(
+        self,
+        agent_run: AgentRun,
+        *,
+        n_initial_rollouts_to_sample: int,
+        n_combinations_to_sample: int,
+        n_reflection_rollouts_to_sample: int,
+        **kwargs: Any,
+    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+        if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
+            raise ValueError(
+                "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+            )
+        if self.cfg.n_rollouts_per_input > n_reflection_rollouts_to_sample:
+            raise ValueError(
+                "n_reflection_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+            )
+
+        first_step_rollouts: list[dict[str, Any]] = []
+        first_step_rollout_metadata: list[dict[str, Any] | None] = []
+        first_step_combinations: list[list[dict[str, Any]]] = []
+        second_step_rollouts: list[list[dict[str, Any]]] = []
+        second_step_rollout_metadata: list[list[dict[str, Any] | None]] = []
+
+        ##########
+        # Step 1 #
+        ##########
+
+        pbar_first = tqdm(
+            total=n_initial_rollouts_to_sample, desc="Stage 1: Initial rollouts", leave=False
+        )
+
+        async def _execute_first_stage():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                first_step_rollouts.append(result)
+                first_step_rollout_metadata.append(metadata)
+            pbar_first.update(1)
+
+        # Collect rollouts of the first stage
+        async with anyio.create_task_group() as tg_first:
+            for _ in range(n_initial_rollouts_to_sample):
+                tg_first.start_soon(_execute_first_stage)
+
+        pbar_first.close()
+
+        if len(first_step_rollouts) < self.cfg.n_rollouts_per_input:
+            raise ValueError("Not enough first step rollouts to sample combinations")
+
+        # Sample random k-sized combinations of the first step rollouts
+        for _ in range(n_combinations_to_sample):
+            combination = random.sample(first_step_rollouts, self.cfg.n_rollouts_per_input)
+            first_step_combinations.append(combination)
+            second_step_rollouts.append([])
+            second_step_rollout_metadata.append([])
+
+        ##########
+        # Step 2 #
+        ##########
+
+        pbar_second = tqdm(
+            total=n_combinations_to_sample, desc="Stage 2: Combinations", leave=False
+        )
+
+        async with anyio.create_task_group() as tg_second:
+
+            async def _execute_second_stage(i: int, combination: list[dict[str, Any]]):
+                pbar_third = tqdm(
+                    total=n_reflection_rollouts_to_sample,
+                    desc=f"Stage 2: Combination {i+1}/{n_combinations_to_sample}",
+                    leave=False,
+                )
+
+                async def _execute_second_stage_inner():
+                    result, metadata = await self.one_rollout_second_stage(agent_run, combination)
+                    if result is not None:
+                        second_step_rollouts[i].append(result)
+                        second_step_rollout_metadata[i].append(metadata)
+                    pbar_third.update(1)
+
+                async with anyio.create_task_group() as tg:
+                    for _ in range(n_reflection_rollouts_to_sample):
+                        tg.start_soon(_execute_second_stage_inner)
+
+                pbar_third.close()
+                pbar_second.update(1)
+
+            for i, combination in enumerate(first_step_combinations):
+                tg_second.start_soon(_execute_second_stage, i, combination)
+
+        pbar_second.close()
+
+        output_distributions = compute_output_distributions(
+            [sublist for el in second_step_rollouts for sublist in el],
+            self.cfg.output_schema,
+            get_agreement_keys(self.cfg.output_schema),
+        )
+
+        return output_distributions, {
+            "first_step_rollouts": first_step_rollouts,
+            "first_step_rollout_metadata": first_step_rollout_metadata,
+            "first_step_combinations": first_step_combinations,
+            "second_step_rollouts": second_step_rollouts,
+            "second_step_rollout_metadata": second_step_rollout_metadata,
+        }
+
+
+def build_judge(rubric: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None):
+    if rubric.judge_variant == JudgeVariant.MAJORITY:
+        return MajorityVotingJudge(rubric, llm_svc, docent_collection_id)
+    elif rubric.judge_variant == JudgeVariant.MULTI_REFLECT:
+        return MultiReflectionJudge(rubric, llm_svc, docent_collection_id)
+    raise ValueError(f"Invalid variant: {rubric.judge_variant}")