docent-python 0.1.22a0__py3-none-any.whl → 0.1.24a0__py3-none-any.whl

This diff compares the contents of two publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.


docent/judges/impl.py CHANGED
@@ -1,71 +1,275 @@
-import json
+import random
+import re
 from abc import ABC, abstractmethod
-from typing import Any
+from contextlib import nullcontext
+from typing import Any, Sequence
 
+import anyio
+import yaml
+from pydantic_core import to_jsonable_python
+from tqdm.auto import tqdm
+
+from docent._llm_util.data_models.exceptions import ValidationFailedException
 from docent._llm_util.data_models.llm_output import LLMOutput
-from docent._llm_util.data_models.simple_svc import BaseLLMService, SimpleLLMService
+from docent._llm_util.llm_svc import BaseLLMService
 from docent._log_util import get_logger
 from docent.data_models.agent_run import AgentRun
-from docent.judges.types import JudgeResult, ResultType, Rubric
-from docent.judges.util.parse_output import parse_and_validate_llm_output
+from docent.data_models.chat.message import (
+    AssistantMessage,
+    ChatMessage,
+    ToolMessage,
+    UserMessage,
+)
+from docent.data_models.chat.tool import ToolInfo
+from docent.judges.types import JudgeResult, JudgeVariant, ResultType, Rubric
+from docent.judges.util.parse_output import parse_and_validate_output_str
 from docent.judges.util.voting import (
-    compute_output_distribution,
+    JudgeOutputDistribution,
+    compute_output_distributions,
     find_modal_result,
     get_agreement_keys,
 )
+from docent.trace import agent_run_context, agent_run_metadata
 
 logger = get_logger(__name__)
 
 
 class BaseJudge(ABC):
-    def __init__(self, cfg: Rubric, llm_svc: BaseLLMService):
+    def __init__(
+        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
+    ):
         self.cfg = cfg
         self.llm_svc = llm_svc
+        self.docent_collection_id = docent_collection_id
 
     @abstractmethod
-    async def __call__(self, agent_run: AgentRun, *args: Any, **kwargs: Any) -> JudgeResult | None:
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
         """Returns None if all rollouts failed to produce a valid output."""
 
+    @abstractmethod
+    async def estimate_output_distrs(
+        self, agent_run: AgentRun, **kwargs: Any
+    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+        """Estimate the output distribution of each output key."""
 
-class MajorityVotingJudge(BaseJudge):
-    """Rolls out the judge multiple times, then uses majority voting to determine the final result."""
-
-    def __init__(
-        self,
-        cfg: Rubric,
-        n_rollouts_per_input: int,
-        llm_svc: BaseLLMService = SimpleLLMService(),
-    ):
-        super().__init__(cfg, llm_svc)
-        self.n_rollouts_per_input = n_rollouts_per_input
-
-    async def __call__(
-        self,
-        agent_run: AgentRun,
-        max_concurrency: int = 10,
-    ) -> JudgeResult | None:
+    def _get_validation_callback(self, agent_run: AgentRun):
         async def _validation_callback(batch_index: int, llm_output: LLMOutput):
-            parse_and_validate_llm_output(llm_output, self.cfg.output_schema, agent_run)
-
-        prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
+            validated_output = self._validate_first_response_tag_or_entire_output(
+                llm_output.first_text or "", agent_run
+            )
+            if validated_output is None:
+                raise ValidationFailedException(
+                    "Validation failed", failed_output=llm_output.first_text
+                )
+
+        return _validation_callback
+
+    async def one_rollout(
+        self, agent_run: AgentRun
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        with agent_run_context() if self.docent_collection_id is not None else nullcontext():
+            if self.cfg.rollout_type == "single_turn":
+                output, metadata = await self.one_single_turn_rollout(agent_run)
+            elif self.cfg.rollout_type == "multi_turn":
+                output, metadata = await self.one_multi_turn_rollout(
+                    agent_run, max_turns=10, max_steps_per_turn=5
+                )
+            else:
+                raise ValueError(f"Invalid rollout type: {self.cfg.rollout_type}")
+
+        if self.docent_collection_id is not None:
+            agent_run_metadata(
+                {
+                    "agent_run_id": agent_run.id,
+                    "judge_output": output,
+                    "judge_rollout_metadata": to_jsonable_python(metadata),
+                }
+            )
+
+        return output, metadata
+
+    def _validate_first_response_tag_or_entire_output(
+        self, output_str: str, agent_run: AgentRun
+    ) -> dict[str, Any] | None:
+        """Validate the first <response> tag in the output string.
+        For backward compatibility, also try to validate the entire output as JSON, for
+        old system prompts that don't ask for <response> tags.
+
+        Args:
+            output_str: The output string to validate
+            agent_run: The agent run to validate against
+
+        Returns:
+            The validated output if successful, None otherwise
+        """
+        response_matches = re.findall(r"<response>(.*?)</response>", output_str, re.DOTALL)
+
+        # Try to validate any match; take the first
+        for response_text in response_matches:
+            try:
+                validated_output = parse_and_validate_output_str(
+                    response_text, self.cfg.output_schema, agent_run
+                )
+                return validated_output
+            except ValidationFailedException:
+                continue  # Try the next match if validation fails
+
+        # Try to validate the entire output as JSON
+        # But only if the output _didn't_ contain a <response>...</response> tag
+        if not response_matches:
+            try:
+                validated_output = parse_and_validate_output_str(
+                    output_str, self.cfg.output_schema, agent_run
+                )
+                return validated_output
+            except ValidationFailedException:
+                pass
+
+        return None
+
+    ########################
+    # Single turn rollouts #
+    ########################
+
+    async def one_single_turn_rollout(
+        self, agent_run: AgentRun
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        prompt = [UserMessage(content=self.cfg.materialize_system_prompt(agent_run))]
         outputs = await self.llm_svc.get_completions(
-            inputs=[prompt for _ in range(self.n_rollouts_per_input)],
+            inputs=[prompt],
             model_options=[self.cfg.judge_model],
             max_new_tokens=16384,
             timeout=180.0,
             use_cache=False,
-            validation_callback=_validation_callback,
-            max_concurrency=max_concurrency,
+            validation_callback=self._get_validation_callback(agent_run),
         )
+        output_str = outputs[0].first_text
+
+        # Extract all <response>...</response> tags from the current message
+        validated_output = self._validate_first_response_tag_or_entire_output(
+            output_str or "", agent_run
+        )
+        if validated_output is not None:
+            return validated_output, {"full_output": output_str}
+        else:
+            return None, None
+
+    #######################
+    # Multi-turn rollouts #
+    #######################
+
+    async def one_multi_turn_rollout(
+        self, agent_run: AgentRun, max_turns: int, max_steps_per_turn: int
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        msgs = [UserMessage(content=self.cfg.materialize_system_prompt(agent_run))]
+        for _ in range(max_turns):
+            msgs = await self.agent_one_turn(msgs, max_steps_per_turn=max_steps_per_turn)
+
+            last_msg_content = msgs[-1].text if msgs else None
+            # Extract all <response>...</response> tags from the current message
+            # Return if we find a valid response; otherwise, continue
+            validated_output = self._validate_first_response_tag_or_entire_output(
+                last_msg_content or "", agent_run
+            )
+            if validated_output is not None:
+                # When returning, strip out the system message, which duplicates the agent run
+                # content many times.
+                return validated_output, {"rollout_messages": msgs[1:]}
+
+        # No <response>...</response> tags with valid JSON,so return None
+        return None, None
+
+    async def agent_one_turn(self, init_msgs: Sequence[ChatMessage], max_steps_per_turn: int):
+        """Given a list of messages, run one turn of the agent.
+        The agent may invoke tools, so we loop until there are no more to handle.
+        """
+
+        msgs = list(init_msgs)  # Shallow copy is fine
+        for _ in range(max_steps_per_turn):
+            last_msg = msgs[-1]
+            if last_msg.role == "system" or last_msg.role == "user" or last_msg.role == "tool":
+                outputs = await self.llm_svc.get_completions(
+                    inputs=[msgs],
+                    model_options=[self.cfg.judge_model],
+                    tools=[
+                        ToolInfo(
+                            name="step_finished",
+                            description="Call this tool to indicate that you have finished one step in the decision procedure",
+                        )
+                    ],
+                    max_new_tokens=16384,
+                    timeout=180.0,
+                    use_cache=False,
+                )
+                output = outputs[0].first
+                if output is None:
+                    # FIXME(mengk): handle empty completion
+                    raise ValueError("Empty completion in agent one turn")
+                new_assistant_msg = AssistantMessage(
+                    content=output.text or "", tool_calls=output.tool_calls
+                )
+                msgs.append(new_assistant_msg)
+            elif last_msg.role == "assistant":
+                if last_msg.tool_calls is not None:
+                    msgs.extend(
+                        [
+                            ToolMessage(
+                                content="Step completed",
+                                tool_call_id=tool_call.id,
+                            )
+                            for tool_call in last_msg.tool_calls
+                        ]
+                    )
+                else:
+                    break  # Terminate if there are no more tool calls to handle
+            else:
+                raise ValueError(f"Unknown message role: {last_msg.role}")
+        return msgs
 
-        # Process each rollout independently
-        indep_results: list[dict[str, Any]] = []
-        for output in outputs:
-            if validated_output := parse_and_validate_llm_output(
-                output, self.cfg.output_schema, agent_run
-            ):
-                indep_results.append(validated_output)
 
+class SingleRolloutJudge(BaseJudge):
+    """Rolls out the judge once."""
+
+    def __init__(self, cfg: Rubric, llm_svc: BaseLLMService):
+        super().__init__(cfg, llm_svc)
+
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+        output, metadata = await self.one_rollout(agent_run)
+        if output is None:
+            return None
+        else:
+            return JudgeResult(
+                agent_run_id=agent_run.id,
+                rubric_id=self.cfg.id,
+                rubric_version=self.cfg.version,
+                output=output,
+                result_metadata={"rollout_metadata": metadata},
+                result_type=ResultType.DIRECT_RESULT,
+            )
+
+
+class MajorityVotingJudge(BaseJudge):
+    """Rolls out the judge multiple times, then uses majority voting to determine the final result."""
+
+    def __init__(
+        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
+    ):
+        super().__init__(cfg, llm_svc, docent_collection_id)
+
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+        indep_results: list[dict[str, Any]] = []
+        indep_rollout_metadata: list[dict[str, Any] | None] = []
+
+        async def _execute():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                indep_results.append(result)
+                indep_rollout_metadata.append(metadata)
+
+        # Run rollouts concurrently
+        async with anyio.create_task_group() as tg:
+            for _ in range(self.cfg.n_rollouts_per_input):
+                tg.start_soon(_execute)
         if not indep_results:
             return None
 
@@ -79,7 +283,7 @@ class MajorityVotingJudge(BaseJudge):
         final_output = indep_results[final_max_idx]
 
         # Compute the distribution of the output across the agreement keys
-        final_output_distribution = compute_output_distribution(
+        final_output_distributions = compute_output_distributions(
             indep_results, self.cfg.output_schema, agreement_keys
         )
 
@@ -94,54 +298,122 @@ class MajorityVotingJudge(BaseJudge):
                 "final_results": indep_results,
                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                 "final_max_idx": final_max_idx,
-                "final_output_distribution": final_output_distribution,
+                "final_output_distributions": final_output_distributions,
+                "final_rollout_metadata": indep_rollout_metadata,
             },
             result_type=ResultType.DIRECT_RESULT,
         )
 
+    async def estimate_output_distrs(
+        self, agent_run: AgentRun, *, n_initial_rollouts_to_sample: int, **kwargs: Any
+    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+        if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
+            raise ValueError(
+                "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+            )
+
+        indep_results: list[dict[str, Any]] = []
+        indep_rollout_metadata: list[dict[str, Any] | None] = []
+        pbar = tqdm(total=n_initial_rollouts_to_sample, desc="Independent rollouts", leave=False)
+
+        async def _execute():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                indep_results.append(result)
+                indep_rollout_metadata.append(metadata)
+            pbar.update(1)
+
+        # Run rollouts concurrently
+        async with anyio.create_task_group() as tg:
+            for _ in range(n_initial_rollouts_to_sample):
+                tg.start_soon(_execute)
+
+        pbar.close()
+
+        if not indep_results:
+            return None
+
+        # Compute the probability vector for each agreement key
+        distributions = compute_output_distributions(
+            indep_results, self.cfg.output_schema, get_agreement_keys(self.cfg.output_schema)
+        )
+
+        return distributions, {
+            "first_step_rollouts": indep_results,
+            "first_step_rollout_metadata": indep_rollout_metadata,
+        }
+
 
 class MultiReflectionJudge(BaseJudge):
     """Rolls out the judge multiple times, then uses reflection to determine the final result."""
 
     def __init__(
-        self,
-        cfg: Rubric,
-        n_rollouts_per_input: int,
-        llm_svc: BaseLLMService = SimpleLLMService(),
+        self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
     ):
-        super().__init__(cfg, llm_svc)
-        self.n_rollouts_per_input = n_rollouts_per_input
-
-    async def __call__(
-        self,
-        agent_run: AgentRun,
-        max_concurrency: int = 10,
-    ) -> JudgeResult | None:
-        rubric = self.cfg
-
-        async def _validation_callback(batch_index: int, llm_output: LLMOutput):
-            parse_and_validate_llm_output(llm_output, rubric.output_schema, agent_run)
-
-        # Run several independent rollouts
-        prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
+        super().__init__(cfg, llm_svc, docent_collection_id)
+
+    async def one_rollout_second_stage(
+        self, agent_run: AgentRun, first_stage_results: list[dict[str, Any]]
+    ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+        """Reflect on the results of the first stage of rollouts.
+        TODO(mengk): this is only done in a single-turn way. We should generalize this to multi-turn.
+        """
+
+        # Construct *single* reflection prompt
+        first_stage_results_text = "\n\n".join(
+            [
+                f"Rollout {j+1}:\n{yaml.dump(r, width=float('inf'))}"
+                for j, r in enumerate(first_stage_results)
+            ]
+        )
+        reflection_instruction = (
+            f"We have sampled a judge {len(first_stage_results)} times to get {len(first_stage_results)} independent answers to the same rubric evaluation:\n"
+            f"{first_stage_results_text}\n\n"
+            f"Please reflect on these answers. Consider all the information and evidence presented. "
+            f"Return a final answer in the same JSON format as before."
+        )
+        reflection_prompt = [
+            # Original system prompt
+            {"role": "system", "content": self.cfg.materialize_system_prompt(agent_run)},
+            # Additional reflection instruction as a user message (kind of awkward)
+            {"role": "user", "content": reflection_instruction},
+        ]
+
+        # Ask the judge to reflect on the others' results
         outputs = await self.llm_svc.get_completions(
-            inputs=[prompt for _ in range(self.n_rollouts_per_input)],
-            model_options=[rubric.judge_model],
+            inputs=[reflection_prompt],
+            model_options=[self.cfg.judge_model],
             max_new_tokens=16384,
             timeout=180.0,
             use_cache=False,
-            validation_callback=_validation_callback,
-            max_concurrency=max_concurrency,
+            validation_callback=self._get_validation_callback(agent_run),
         )
+        output_str = outputs[0].first_text
 
-        # Process each rollout
-        indep_results: list[dict[str, Any]] = []
-        for output in outputs:
-            if output.first_text is None:
-                continue
-            if v_output := parse_and_validate_llm_output(output, rubric.output_schema, agent_run):
-                indep_results.append(v_output)
+        validated_output = self._validate_first_response_tag_or_entire_output(
+            output_str or "", agent_run
+        )
+        if validated_output is not None:
+            return validated_output, None
+        else:
+            return None, None
 
+    async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+        rubric = self.cfg
+
+        indep_results: list[dict[str, Any]] = []
+        indep_rollout_metadata: list[dict[str, Any] | None] = []
+
+        async def _execute():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                indep_results.append(result)
+                indep_rollout_metadata.append(metadata)
+
+        # Stage 1: run rollouts concurrently
+        async with anyio.create_task_group() as tg:
+            for _ in range(self.cfg.n_rollouts_per_input):
+                tg.start_soon(_execute)
         if not indep_results:
             return None
 
@@ -151,61 +423,28 @@ class MultiReflectionJudge(BaseJudge):
             indep_results, agreement_keys
         )
 
-        def _get_reflection_prompt(cur_index: int):
-            # Current result
-            result = indep_results[cur_index]
-            # Get other results (excluding the current one)
-            other_results = [r for j, r in enumerate(indep_results) if j != cur_index]
-
-            # Create the reflection message
-            other_results_text = "\n\n".join(
-                [f"Answer {j+1}:\n{json.dumps(r, indent=2)}" for j, r in enumerate(other_results)]
-            )
-
-            reflection_instruction = (
-                f"Here are {len(other_results)} other independent answers to the same rubric evaluation:\n\n"
-                f"{other_results_text}\n\n"
-                f"Please reflect on these other answers and your own answer. "
-                f"Consider if any of them have identified important aspects you missed, or if there are disagreements that should be resolved. "
-                f"Then provide your final answer in the same JSON format as before."
-            )
-
-            # Construct the multi-message prompt
-            # 1. Original user message
-            # 2. Assistant message with the rollout's result
-            # 3. New user message asking for reflection
-            return [
-                *prompt,  # Original user message(s)
-                {"role": "assistant", "content": json.dumps(result, indent=2)},
-                {"role": "user", "content": reflection_instruction},
-            ]
-
-        final_results = indep_results.copy()  # Shallow copy
+        # Stage 2: reflect on the results
+        # Shallow copies are fine
+        final_results = indep_results.copy()
+        final_rollout_metadata = indep_rollout_metadata.copy()
         if len(indep_results) > 1:
-            # Ask the judge to reflect on the others' results
-            reflection_outputs = await self.llm_svc.get_completions(
-                inputs=[_get_reflection_prompt(i) for i in range(len(indep_results))],
-                model_options=[rubric.judge_model],
-                max_new_tokens=16384,
-                timeout=180.0,
-                use_cache=False,
-                validation_callback=_validation_callback,
-                max_concurrency=max_concurrency,
-            )
+            candidate_final_results: list[dict[str, Any]] = []
+            candidate_final_rollout_metadata: list[dict[str, Any] | None] = []
+
+            async def _execute_second_stage():
+                result, metadata = await self.one_rollout_second_stage(agent_run, indep_results)
+                if result is not None:
+                    candidate_final_results.append(result)
+                    candidate_final_rollout_metadata.append(metadata)
 
-            # Process reflection outputs in the same way as the initial rollouts
-            reflected_results: list[dict[str, Any]] = []
-            for output in reflection_outputs:
-                if output.first_text is None:
-                    continue
-                if v_output := parse_and_validate_llm_output(
-                    output, rubric.output_schema, agent_run
-                ):
-                    reflected_results.append(v_output)
+            async with anyio.create_task_group() as tg:
+                for _ in range(self.cfg.n_rollouts_per_input):
+                    tg.start_soon(_execute_second_stage)
 
             # Use reflected results if we got any, otherwise fall back to original results
-            if reflected_results:
-                final_results = reflected_results
+            if candidate_final_results:
+                final_results = candidate_final_results
+                final_rollout_metadata = candidate_final_rollout_metadata
             else:
                 logger.warning("No reflected results found, falling back to original results")
 
@@ -223,10 +462,126 @@ class MultiReflectionJudge(BaseJudge):
                 "final_results": final_results,
                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                 "final_max_idx": final_max_idx,
+                "final_rollout_metadata": final_rollout_metadata,
                 # Also include initial measurements
                 "indep_results": indep_results,
                 "indep_max_idx": indep_max_idx,
                 "indep_agt_key_modes_and_counts": indep_agt_key_modes_and_counts,
+                "indep_rollout_metadata": indep_rollout_metadata,
             },
             result_type=ResultType.DIRECT_RESULT,
         )
+
+    async def estimate_output_distrs(
+        self,
+        agent_run: AgentRun,
+        *,
+        n_initial_rollouts_to_sample: int,
+        n_combinations_to_sample: int,
+        n_reflection_rollouts_to_sample: int,
+        **kwargs: Any,
+    ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+        if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
+            raise ValueError(
+                "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+            )
+        if self.cfg.n_rollouts_per_input > n_reflection_rollouts_to_sample:
+            raise ValueError(
+                "n_reflection_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+            )
+
+        first_step_rollouts: list[dict[str, Any]] = []
+        first_step_rollout_metadata: list[dict[str, Any] | None] = []
+        first_step_combinations: list[list[dict[str, Any]]] = []
+        second_step_rollouts: list[list[dict[str, Any]]] = []
+        second_step_rollout_metadata: list[list[dict[str, Any] | None]] = []
+
+        ##########
+        # Step 1 #
+        ##########
+
+        pbar_first = tqdm(
+            total=n_initial_rollouts_to_sample, desc="Stage 1: Initial rollouts", leave=False
+        )
+
+        async def _execute_first_stage():
+            result, metadata = await self.one_rollout(agent_run)
+            if result is not None:
+                first_step_rollouts.append(result)
+                first_step_rollout_metadata.append(metadata)
+            pbar_first.update(1)
+
+        # Collect rollouts of the first stage
+        async with anyio.create_task_group() as tg_first:
+            for _ in range(n_initial_rollouts_to_sample):
+                tg_first.start_soon(_execute_first_stage)
+
+        pbar_first.close()
+
+        if len(first_step_rollouts) < self.cfg.n_rollouts_per_input:
+            raise ValueError("Not enough first step rollouts to sample combinations")
+
+        # Sample random k-sized combinations of the first step rollouts
+        for _ in range(n_combinations_to_sample):
+            combination = random.sample(first_step_rollouts, self.cfg.n_rollouts_per_input)
+            first_step_combinations.append(combination)
+            second_step_rollouts.append([])
+            second_step_rollout_metadata.append([])
+
+        ##########
+        # Step 2 #
+        ##########
+
+        pbar_second = tqdm(
+            total=n_combinations_to_sample, desc="Stage 2: Combinations", leave=False
+        )
+
+        async with anyio.create_task_group() as tg_second:
+
+            async def _execute_second_stage(i: int, combination: list[dict[str, Any]]):
+                pbar_third = tqdm(
+                    total=n_reflection_rollouts_to_sample,
+                    desc=f"Stage 2: Combination {i+1}/{n_combinations_to_sample}",
+                    leave=False,
+                )
+
+                async def _execute_second_stage_inner():
+                    result, metadata = await self.one_rollout_second_stage(agent_run, combination)
+                    if result is not None:
+                        second_step_rollouts[i].append(result)
+                        second_step_rollout_metadata[i].append(metadata)
+                    pbar_third.update(1)
+
+                async with anyio.create_task_group() as tg:
+                    for _ in range(n_reflection_rollouts_to_sample):
+                        tg.start_soon(_execute_second_stage_inner)
+
+                pbar_third.close()
+                pbar_second.update(1)
+
+            for i, combination in enumerate(first_step_combinations):
+                tg_second.start_soon(_execute_second_stage, i, combination)
+
+        pbar_second.close()
+
+        output_distributions = compute_output_distributions(
+            [sublist for el in second_step_rollouts for sublist in el],
+            self.cfg.output_schema,
+            get_agreement_keys(self.cfg.output_schema),
+        )
+
+        return output_distributions, {
+            "first_step_rollouts": first_step_rollouts,
+            "first_step_rollout_metadata": first_step_rollout_metadata,
+            "first_step_combinations": first_step_combinations,
+            "second_step_rollouts": second_step_rollouts,
+            "second_step_rollout_metadata": second_step_rollout_metadata,
+        }
+
+
+def build_judge(rubric: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None):
+    if rubric.judge_variant == JudgeVariant.MAJORITY:
+        return MajorityVotingJudge(rubric, llm_svc, docent_collection_id)
+    elif rubric.judge_variant == JudgeVariant.MULTI_REFLECT:
+        return MultiReflectionJudge(rubric, llm_svc, docent_collection_id)
+    raise ValueError(f"Invalid variant: {rubric.judge_variant}")
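
For orientation, the following is a minimal, self-contained sketch of the parsing fallback that the new _validate_first_response_tag_or_entire_output method introduces: prefer <response>...</response> blocks, and only fall back to treating the entire completion as JSON when no tag is present. This is not the docent implementation; it uses plain json.loads and omits the rubric-schema validation that parse_and_validate_output_str performs.

import json
import re


def extract_judge_output(output_str: str) -> dict | None:
    # Collect every <response>...</response> block, including multi-line ones.
    matches = re.findall(r"<response>(.*?)</response>", output_str, re.DOTALL)

    # Accept the first block that parses as JSON.
    for candidate in matches:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue

    # Backward-compatible path: only if no tag was present, try the whole string.
    if not matches:
        try:
            return json.loads(output_str)
        except json.JSONDecodeError:
            pass

    return None


print(extract_judge_output('Reasoning... <response>{"label": "match"}</response>'))  # {'label': 'match'}
print(extract_judge_output('{"label": "no_match"}'))  # old-style output without tags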
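
And a sketch of how the new build_judge entry point might be wired up. The construction of the Rubric, the LLM service, and the AgentRun is not shown in this diff, so they are passed in as opaque arguments here; the final run line is hypothetical wiring, not part of the package.

import anyio

from docent.judges.impl import build_judge


async def score_run(rubric, llm_svc, agent_run):
    # Dispatches to MajorityVotingJudge or MultiReflectionJudge based on rubric.judge_variant.
    judge = build_judge(rubric, llm_svc)
    result = await judge(agent_run)  # JudgeResult | None (None if no rollout validated)
    return None if result is None else result.output


# anyio.run(score_run, rubric, llm_svc, agent_run)  # hypothetical wiring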