docent-python 0.1.19a0__py3-none-any.whl → 0.1.27a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of docent-python might be problematic.

Files changed (38)
  1. docent/_llm_util/__init__.py +0 -0
  2. docent/_llm_util/data_models/__init__.py +0 -0
  3. docent/_llm_util/data_models/exceptions.py +48 -0
  4. docent/_llm_util/data_models/llm_output.py +331 -0
  5. docent/_llm_util/llm_cache.py +193 -0
  6. docent/_llm_util/llm_svc.py +472 -0
  7. docent/_llm_util/model_registry.py +130 -0
  8. docent/_llm_util/providers/__init__.py +0 -0
  9. docent/_llm_util/providers/anthropic.py +537 -0
  10. docent/_llm_util/providers/common.py +41 -0
  11. docent/_llm_util/providers/google.py +530 -0
  12. docent/_llm_util/providers/openai.py +745 -0
  13. docent/_llm_util/providers/openrouter.py +375 -0
  14. docent/_llm_util/providers/preference_types.py +104 -0
  15. docent/_llm_util/providers/provider_registry.py +164 -0
  16. docent/data_models/__init__.py +2 -2
  17. docent/data_models/agent_run.py +1 -0
  18. docent/data_models/judge.py +7 -4
  19. docent/data_models/transcript.py +2 -0
  20. docent/data_models/util.py +170 -0
  21. docent/judges/__init__.py +23 -0
  22. docent/judges/analysis.py +77 -0
  23. docent/judges/impl.py +587 -0
  24. docent/judges/runner.py +129 -0
  25. docent/judges/stats.py +205 -0
  26. docent/judges/types.py +311 -0
  27. docent/judges/util/forgiving_json.py +108 -0
  28. docent/judges/util/meta_schema.json +86 -0
  29. docent/judges/util/meta_schema.py +29 -0
  30. docent/judges/util/parse_output.py +87 -0
  31. docent/judges/util/voting.py +139 -0
  32. docent/sdk/client.py +181 -44
  33. docent/trace.py +362 -44
  34. {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/METADATA +11 -5
  35. docent_python-0.1.27a0.dist-info/RECORD +59 -0
  36. docent_python-0.1.19a0.dist-info/RECORD +0 -32
  37. {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/WHEEL +0 -0
  38. {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/licenses/LICENSE.md +0 -0
docent/judges/impl.py ADDED
@@ -0,0 +1,587 @@
+ import random
+ import re
+ from abc import ABC, abstractmethod
+ from contextlib import nullcontext
+ from typing import Any, Sequence
+
+ import anyio
+ import yaml
+ from pydantic_core import to_jsonable_python
+ from tqdm.auto import tqdm
+
+ from docent._llm_util.data_models.exceptions import ValidationFailedException
+ from docent._llm_util.data_models.llm_output import LLMOutput
+ from docent._llm_util.llm_svc import BaseLLMService
+ from docent._log_util import get_logger
+ from docent.data_models.agent_run import AgentRun
+ from docent.data_models.chat.message import (
+     AssistantMessage,
+     ChatMessage,
+     ToolMessage,
+     UserMessage,
+ )
+ from docent.data_models.chat.tool import ToolInfo
+ from docent.judges.types import JudgeResult, JudgeVariant, ResultType, Rubric
+ from docent.judges.util.parse_output import parse_and_validate_output_str
+ from docent.judges.util.voting import (
+     JudgeOutputDistribution,
+     compute_output_distributions,
+     find_modal_result,
+     get_agreement_keys,
+ )
+ from docent.trace import agent_run_context, agent_run_metadata
+
+ logger = get_logger(__name__)
+
+
+ class BaseJudge(ABC):
+     def __init__(
+         self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
+     ):
+         self.cfg = cfg
+         self.llm_svc = llm_svc
+         self.docent_collection_id = docent_collection_id
+
+     @abstractmethod
+     async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+         """Returns None if all rollouts failed to produce a valid output."""
+
+     @abstractmethod
+     async def estimate_output_distrs(
+         self, agent_run: AgentRun, **kwargs: Any
+     ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+         """Estimate the output distribution of each output key."""
+
+     def _get_validation_callback(self, agent_run: AgentRun):
+         async def _validation_callback(batch_index: int, llm_output: LLMOutput):
+             validated_output = self._validate_first_response_tag_or_entire_output(
+                 llm_output.first_text or "", agent_run
+             )
+             if validated_output is None:
+                 raise ValidationFailedException(
+                     "Validation failed", failed_output=llm_output.first_text
+                 )
+
+         return _validation_callback
+
+     async def one_rollout(
+         self, agent_run: AgentRun
+     ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+         with agent_run_context() if self.docent_collection_id is not None else nullcontext():
+             if self.cfg.rollout_type == "single_turn":
+                 output, metadata = await self.one_single_turn_rollout(agent_run)
+             elif self.cfg.rollout_type == "multi_turn":
+                 output, metadata = await self.one_multi_turn_rollout(
+                     agent_run, max_turns=10, max_steps_per_turn=5
+                 )
+             else:
+                 raise ValueError(f"Invalid rollout type: {self.cfg.rollout_type}")
+
+             if self.docent_collection_id is not None:
+                 agent_run_metadata(
+                     {
+                         "agent_run_id": agent_run.id,
+                         "judge_output": output,
+                         "judge_rollout_metadata": to_jsonable_python(metadata),
+                     }
+                 )
+
+             return output, metadata
+
+     def _validate_first_response_tag_or_entire_output(
+         self, output_str: str, agent_run: AgentRun
+     ) -> dict[str, Any] | None:
+         """Validate the first <response> tag in the output string.
+         For backward compatibility, also try to validate the entire output as JSON, for
+         old system prompts that don't ask for <response> tags.
+
+         Args:
+             output_str: The output string to validate
+             agent_run: The agent run to validate against
+
+         Returns:
+             The validated output if successful, None otherwise
+         """
+         response_matches = re.findall(r"<response>(.*?)</response>", output_str, re.DOTALL)
+
+         # Try to validate any match; take the first
+         for response_text in response_matches:
+             try:
+                 validated_output = parse_and_validate_output_str(
+                     response_text, self.cfg.output_schema, agent_run
+                 )
+                 return validated_output
+             except ValidationFailedException:
+                 continue  # Try the next match if validation fails
+
+         # Try to validate the entire output as JSON
+         # But only if the output _didn't_ contain a <response>...</response> tag
+         if not response_matches:
+             try:
+                 validated_output = parse_and_validate_output_str(
+                     output_str, self.cfg.output_schema, agent_run
+                 )
+                 return validated_output
+             except ValidationFailedException:
+                 pass
+
+         return None
+
+     ########################
+     # Single turn rollouts #
+     ########################
+
+     async def one_single_turn_rollout(
+         self, agent_run: AgentRun
+     ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+         prompt = [UserMessage(content=self.cfg.materialize_system_prompt(agent_run))]
+         outputs = await self.llm_svc.get_completions(
+             inputs=[prompt],
+             model_options=[self.cfg.judge_model],
+             max_new_tokens=16384,
+             timeout=180.0,
+             use_cache=False,
+             validation_callback=self._get_validation_callback(agent_run),
+         )
+         output_str = outputs[0].first_text
+
+         # Extract all <response>...</response> tags from the current message
+         validated_output = self._validate_first_response_tag_or_entire_output(
+             output_str or "", agent_run
+         )
+         if validated_output is not None:
+             return validated_output, {"full_output": output_str}
+         else:
+             return None, None
+
+     #######################
+     # Multi-turn rollouts #
+     #######################
+
+     async def one_multi_turn_rollout(
+         self, agent_run: AgentRun, max_turns: int, max_steps_per_turn: int
+     ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+         msgs = [UserMessage(content=self.cfg.materialize_system_prompt(agent_run))]
+         for _ in range(max_turns):
+             msgs = await self.agent_one_turn(msgs, max_steps_per_turn=max_steps_per_turn)
+
+             last_msg_content = msgs[-1].text if msgs else None
+             # Extract all <response>...</response> tags from the current message
+             # Return if we find a valid response; otherwise, continue
+             validated_output = self._validate_first_response_tag_or_entire_output(
+                 last_msg_content or "", agent_run
+             )
+             if validated_output is not None:
+                 # When returning, strip out the system message, which duplicates the agent run
+                 # content many times.
+                 return validated_output, {"rollout_messages": msgs[1:]}
+
+         # No <response>...</response> tags with valid JSON, so return None
+         return None, None
+
+     async def agent_one_turn(self, init_msgs: Sequence[ChatMessage], max_steps_per_turn: int):
+         """Given a list of messages, run one turn of the agent.
+         The agent may invoke tools, so we loop until there are no more to handle.
+         """
+
+         msgs = list(init_msgs)  # Shallow copy is fine
+         for _ in range(max_steps_per_turn):
+             last_msg = msgs[-1]
+             if last_msg.role == "system" or last_msg.role == "user" or last_msg.role == "tool":
+                 outputs = await self.llm_svc.get_completions(
+                     inputs=[msgs],
+                     model_options=[self.cfg.judge_model],
+                     tools=[
+                         ToolInfo(
+                             name="step_finished",
+                             description="Call this tool to indicate that you have finished one step in the decision procedure",
+                         )
+                     ],
+                     max_new_tokens=16384,
+                     timeout=180.0,
+                     use_cache=False,
+                 )
+                 output = outputs[0].first
+                 if output is None:
+                     # FIXME(mengk): handle empty completion
+                     raise ValueError("Empty completion in agent one turn")
+                 new_assistant_msg = AssistantMessage(
+                     content=output.text or "", tool_calls=output.tool_calls
+                 )
+                 msgs.append(new_assistant_msg)
+             elif last_msg.role == "assistant":
+                 if last_msg.tool_calls is not None:
+                     msgs.extend(
+                         [
+                             ToolMessage(
+                                 content="Step completed",
+                                 tool_call_id=tool_call.id,
+                             )
+                             for tool_call in last_msg.tool_calls
+                         ]
+                     )
+                 else:
+                     break  # Terminate if there are no more tool calls to handle
+             else:
+                 raise ValueError(f"Unknown message role: {last_msg.role}")
+         return msgs
+
+
+ class SingleRolloutJudge(BaseJudge):
+     """Rolls out the judge once."""
+
+     def __init__(self, cfg: Rubric, llm_svc: BaseLLMService):
+         super().__init__(cfg, llm_svc)
+
+     async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+         output, metadata = await self.one_rollout(agent_run)
+         if output is None:
+             return None
+         else:
+             return JudgeResult(
+                 agent_run_id=agent_run.id,
+                 rubric_id=self.cfg.id,
+                 rubric_version=self.cfg.version,
+                 output=output,
+                 result_metadata={"rollout_metadata": metadata},
+                 result_type=ResultType.DIRECT_RESULT,
+             )
+
+
+ class MajorityVotingJudge(BaseJudge):
+     """Rolls out the judge multiple times, then uses majority voting to determine the final result."""
+
+     def __init__(
+         self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
+     ):
+         super().__init__(cfg, llm_svc, docent_collection_id)
+
+     async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+         indep_results: list[dict[str, Any]] = []
+         indep_rollout_metadata: list[dict[str, Any] | None] = []
+
+         async def _execute():
+             result, metadata = await self.one_rollout(agent_run)
+             if result is not None:
+                 indep_results.append(result)
+                 indep_rollout_metadata.append(metadata)
+
+         # Run rollouts concurrently
+         async with anyio.create_task_group() as tg:
+             for _ in range(self.cfg.n_rollouts_per_input):
+                 tg.start_soon(_execute)
+         if not indep_results:
+             return None
+
+         # Get a list of the keys that we want to measure agreement on
+         agreement_keys = get_agreement_keys(self.cfg.output_schema)
+
+         # Find the result that best matches modal values
+         final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
+             indep_results, agreement_keys
+         )
+         final_output = indep_results[final_max_idx]
+
+         # Compute the distribution of the output across the agreement keys
+         final_output_distributions = compute_output_distributions(
+             indep_results, self.cfg.output_schema, agreement_keys
+         )
+
+         return JudgeResult(
+             agent_run_id=agent_run.id,
+             rubric_id=self.cfg.id,
+             rubric_version=self.cfg.version,
+             output=final_output,
+             result_metadata={
+                 "agt_keys": agreement_keys,
+                 # Final measurements
+                 "final_results": indep_results,
+                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
+                 "final_max_idx": final_max_idx,
+                 "final_output_distributions": final_output_distributions,
+                 "final_rollout_metadata": indep_rollout_metadata,
+             },
+             result_type=ResultType.DIRECT_RESULT,
+         )
+
+     async def estimate_output_distrs(
+         self, agent_run: AgentRun, *, n_initial_rollouts_to_sample: int, **kwargs: Any
+     ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+         if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
+             raise ValueError(
+                 "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+             )
+
+         indep_results: list[dict[str, Any]] = []
+         indep_rollout_metadata: list[dict[str, Any] | None] = []
+         pbar = tqdm(total=n_initial_rollouts_to_sample, desc="Independent rollouts", leave=False)
+
+         async def _execute():
+             result, metadata = await self.one_rollout(agent_run)
+             if result is not None:
+                 indep_results.append(result)
+                 indep_rollout_metadata.append(metadata)
+             pbar.update(1)
+
+         # Run rollouts concurrently
+         async with anyio.create_task_group() as tg:
+             for _ in range(n_initial_rollouts_to_sample):
+                 tg.start_soon(_execute)
+
+         pbar.close()
+
+         if not indep_results:
+             return None
+
+         # Compute the probability vector for each agreement key
+         distributions = compute_output_distributions(
+             indep_results, self.cfg.output_schema, get_agreement_keys(self.cfg.output_schema)
+         )
+
+         return distributions, {
+             "first_step_rollouts": indep_results,
+             "first_step_rollout_metadata": indep_rollout_metadata,
+         }
+
+
+ class MultiReflectionJudge(BaseJudge):
+     """Rolls out the judge multiple times, then uses reflection to determine the final result."""
+
+     def __init__(
+         self, cfg: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None
+     ):
+         super().__init__(cfg, llm_svc, docent_collection_id)
+
+     async def one_rollout_second_stage(
+         self, agent_run: AgentRun, first_stage_results: list[dict[str, Any]]
+     ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
+         """Reflect on the results of the first stage of rollouts.
+         TODO(mengk): this is only done in a single-turn way. We should generalize this to multi-turn.
+         """
+
+         # Construct *single* reflection prompt
+         first_stage_results_text = "\n\n".join(
+             [
+                 f"Rollout {j+1}:\n{yaml.dump(r, width=float('inf'))}"
+                 for j, r in enumerate(first_stage_results)
+             ]
+         )
+         reflection_instruction = (
+             f"We have sampled a judge {len(first_stage_results)} times to get {len(first_stage_results)} independent answers to the same rubric evaluation:\n"
+             f"{first_stage_results_text}\n\n"
+             f"Please reflect on these answers. Consider all the information and evidence presented. "
+             f"Return a final answer in the same JSON format as before."
+         )
+         reflection_prompt = [
+             # Original system prompt
+             {"role": "system", "content": self.cfg.materialize_system_prompt(agent_run)},
+             # Additional reflection instruction as a user message (kind of awkward)
+             {"role": "user", "content": reflection_instruction},
+         ]
+
+         # Ask the judge to reflect on the others' results
+         outputs = await self.llm_svc.get_completions(
+             inputs=[reflection_prompt],
+             model_options=[self.cfg.judge_model],
+             max_new_tokens=16384,
+             timeout=180.0,
+             use_cache=False,
+             validation_callback=self._get_validation_callback(agent_run),
+         )
+         output_str = outputs[0].first_text
+
+         validated_output = self._validate_first_response_tag_or_entire_output(
+             output_str or "", agent_run
+         )
+         if validated_output is not None:
+             return validated_output, None
+         else:
+             return None, None
+
+     async def __call__(self, agent_run: AgentRun) -> JudgeResult | None:
+         rubric = self.cfg
+
+         indep_results: list[dict[str, Any]] = []
+         indep_rollout_metadata: list[dict[str, Any] | None] = []
+
+         async def _execute():
+             result, metadata = await self.one_rollout(agent_run)
+             if result is not None:
+                 indep_results.append(result)
+                 indep_rollout_metadata.append(metadata)
+
+         # Stage 1: run rollouts concurrently
+         async with anyio.create_task_group() as tg:
+             for _ in range(self.cfg.n_rollouts_per_input):
+                 tg.start_soon(_execute)
+         if not indep_results:
+             return None
+
+         # Compute initial modes
+         agreement_keys = get_agreement_keys(rubric.output_schema)
+         indep_max_idx, indep_agt_key_modes_and_counts = find_modal_result(
+             indep_results, agreement_keys
+         )
+
+         # Stage 2: reflect on the results
+         # Shallow copies are fine
+         final_results = indep_results.copy()
+         final_rollout_metadata = indep_rollout_metadata.copy()
+         if len(indep_results) > 1:
+             candidate_final_results: list[dict[str, Any]] = []
+             candidate_final_rollout_metadata: list[dict[str, Any] | None] = []
+
+             async def _execute_second_stage():
+                 result, metadata = await self.one_rollout_second_stage(agent_run, indep_results)
+                 if result is not None:
+                     candidate_final_results.append(result)
+                     candidate_final_rollout_metadata.append(metadata)
+
+             async with anyio.create_task_group() as tg:
+                 for _ in range(self.cfg.n_rollouts_per_input):
+                     tg.start_soon(_execute_second_stage)
+
+             # Use reflected results if we got any, otherwise fall back to original results
+             if candidate_final_results:
+                 final_results = candidate_final_results
+                 final_rollout_metadata = candidate_final_rollout_metadata
+             else:
+                 logger.warning("No reflected results found, falling back to original results")
+
+         final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
+             final_results, agreement_keys
+         )
+         return JudgeResult(
+             agent_run_id=agent_run.id,
+             rubric_id=rubric.id,
+             rubric_version=rubric.version,
+             output=final_results[final_max_idx],
+             result_metadata={
+                 "agt_keys": agreement_keys,
+                 # Final measurements
+                 "final_results": final_results,
+                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
+                 "final_max_idx": final_max_idx,
+                 "final_rollout_metadata": final_rollout_metadata,
+                 # Also include initial measurements
+                 "indep_results": indep_results,
+                 "indep_max_idx": indep_max_idx,
+                 "indep_agt_key_modes_and_counts": indep_agt_key_modes_and_counts,
+                 "indep_rollout_metadata": indep_rollout_metadata,
+             },
+             result_type=ResultType.DIRECT_RESULT,
+         )
+
+     async def estimate_output_distrs(
+         self,
+         agent_run: AgentRun,
+         *,
+         n_initial_rollouts_to_sample: int,
+         n_combinations_to_sample: int,
+         n_reflection_rollouts_to_sample: int,
+         **kwargs: Any,
+     ) -> None | tuple[dict[str, JudgeOutputDistribution], dict[str, Any]]:
+         if self.cfg.n_rollouts_per_input > n_initial_rollouts_to_sample:
+             raise ValueError(
+                 "n_initial_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+             )
+         if self.cfg.n_rollouts_per_input > n_reflection_rollouts_to_sample:
+             raise ValueError(
+                 "n_reflection_rollouts_to_sample must be greater than or equal to cfg.n_rollouts_per_input"
+             )
+
+         first_step_rollouts: list[dict[str, Any]] = []
+         first_step_rollout_metadata: list[dict[str, Any] | None] = []
+         first_step_combinations: list[list[dict[str, Any]]] = []
+         second_step_rollouts: list[list[dict[str, Any]]] = []
+         second_step_rollout_metadata: list[list[dict[str, Any] | None]] = []
+
+         ##########
+         # Step 1 #
+         ##########
+
+         pbar_first = tqdm(
+             total=n_initial_rollouts_to_sample, desc="Stage 1: Initial rollouts", leave=False
+         )
+
+         async def _execute_first_stage():
+             result, metadata = await self.one_rollout(agent_run)
+             if result is not None:
+                 first_step_rollouts.append(result)
+                 first_step_rollout_metadata.append(metadata)
+             pbar_first.update(1)
+
+         # Collect rollouts of the first stage
+         async with anyio.create_task_group() as tg_first:
+             for _ in range(n_initial_rollouts_to_sample):
+                 tg_first.start_soon(_execute_first_stage)
+
+         pbar_first.close()
+
+         if len(first_step_rollouts) < self.cfg.n_rollouts_per_input:
+             raise ValueError("Not enough first step rollouts to sample combinations")
+
+         # Sample random k-sized combinations of the first step rollouts
+         for _ in range(n_combinations_to_sample):
+             combination = random.sample(first_step_rollouts, self.cfg.n_rollouts_per_input)
+             first_step_combinations.append(combination)
+             second_step_rollouts.append([])
+             second_step_rollout_metadata.append([])
+
+         ##########
+         # Step 2 #
+         ##########
+
+         pbar_second = tqdm(
+             total=n_combinations_to_sample, desc="Stage 2: Combinations", leave=False
+         )
+
+         async with anyio.create_task_group() as tg_second:
+
+             async def _execute_second_stage(i: int, combination: list[dict[str, Any]]):
+                 pbar_third = tqdm(
+                     total=n_reflection_rollouts_to_sample,
+                     desc=f"Stage 2: Combination {i+1}/{n_combinations_to_sample}",
+                     leave=False,
+                 )
+
+                 async def _execute_second_stage_inner():
+                     result, metadata = await self.one_rollout_second_stage(agent_run, combination)
+                     if result is not None:
+                         second_step_rollouts[i].append(result)
+                         second_step_rollout_metadata[i].append(metadata)
+                     pbar_third.update(1)
+
+                 async with anyio.create_task_group() as tg:
+                     for _ in range(n_reflection_rollouts_to_sample):
+                         tg.start_soon(_execute_second_stage_inner)
+
+                 pbar_third.close()
+                 pbar_second.update(1)
+
+             for i, combination in enumerate(first_step_combinations):
+                 tg_second.start_soon(_execute_second_stage, i, combination)
+
+         pbar_second.close()
+
+         output_distributions = compute_output_distributions(
+             [result for sublist in second_step_rollouts for result in sublist],
+             self.cfg.output_schema,
+             get_agreement_keys(self.cfg.output_schema),
+         )
+
+         return output_distributions, {
+             "first_step_rollouts": first_step_rollouts,
+             "first_step_rollout_metadata": first_step_rollout_metadata,
+             "first_step_combinations": first_step_combinations,
+             "second_step_rollouts": second_step_rollouts,
+             "second_step_rollout_metadata": second_step_rollout_metadata,
+         }
+
+
+ def build_judge(rubric: Rubric, llm_svc: BaseLLMService, docent_collection_id: str | None = None):
+     if rubric.judge_variant == JudgeVariant.MAJORITY:
+         return MajorityVotingJudge(rubric, llm_svc, docent_collection_id)
+     elif rubric.judge_variant == JudgeVariant.MULTI_REFLECT:
+         return MultiReflectionJudge(rubric, llm_svc, docent_collection_id)
+     raise ValueError(f"Invalid variant: {rubric.judge_variant}")
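Usage sketch (not part of the package)

For orientation, here is a minimal sketch of how the new judge entry point might be driven, assuming only what this file shows: build_judge(rubric, llm_svc, docent_collection_id=None), a judge call that returns JudgeResult | None, and the module's own imports. How a Rubric or a BaseLLMService instance is actually constructed lives in docent/judges/types.py and docent/_llm_util/llm_svc.py, which are not shown in this diff, so both are taken here as parameters; the helper grade_runs is hypothetical and is not the package's own batch runner (that lives in docent/judges/runner.py, also not shown).

import anyio

from docent._llm_util.llm_svc import BaseLLMService
from docent.data_models.agent_run import AgentRun
from docent.judges.impl import build_judge
from docent.judges.types import JudgeResult, Rubric


async def grade_runs(
    agent_runs: list[AgentRun], rubric: Rubric, llm_svc: BaseLLMService
) -> list[JudgeResult]:
    """Hypothetical helper: judge each agent run against the rubric, keeping valid results."""
    # build_judge picks MajorityVotingJudge or MultiReflectionJudge from rubric.judge_variant.
    judge = build_judge(rubric, llm_svc)
    results: list[JudgeResult] = []

    async def _grade(run: AgentRun) -> None:
        # A judge call returns None when every rollout failed to produce a valid output.
        result = await judge(run)
        if result is not None:
            results.append(result)

    # Judge runs concurrently, mirroring the anyio task-group pattern used inside the judges.
    async with anyio.create_task_group() as tg:
        for run in agent_runs:
            tg.start_soon(_grade, run)

    return results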