docent-python 0.1.14a0__py3-none-any.whl → 0.1.28a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic. Click here for more details.

Files changed (46) hide show
  1. docent/_llm_util/__init__.py +0 -0
  2. docent/_llm_util/data_models/__init__.py +0 -0
  3. docent/_llm_util/data_models/exceptions.py +48 -0
  4. docent/_llm_util/data_models/llm_output.py +331 -0
  5. docent/_llm_util/llm_cache.py +193 -0
  6. docent/_llm_util/llm_svc.py +472 -0
  7. docent/_llm_util/model_registry.py +130 -0
  8. docent/_llm_util/providers/__init__.py +0 -0
  9. docent/_llm_util/providers/anthropic.py +537 -0
  10. docent/_llm_util/providers/common.py +41 -0
  11. docent/_llm_util/providers/google.py +530 -0
  12. docent/_llm_util/providers/openai.py +745 -0
  13. docent/_llm_util/providers/openrouter.py +375 -0
  14. docent/_llm_util/providers/preference_types.py +104 -0
  15. docent/_llm_util/providers/provider_registry.py +164 -0
  16. docent/data_models/__init__.py +2 -0
  17. docent/data_models/agent_run.py +17 -29
  18. docent/data_models/chat/__init__.py +6 -1
  19. docent/data_models/chat/message.py +3 -1
  20. docent/data_models/citation.py +103 -22
  21. docent/data_models/judge.py +19 -0
  22. docent/data_models/metadata_util.py +16 -0
  23. docent/data_models/remove_invalid_citation_ranges.py +23 -10
  24. docent/data_models/transcript.py +25 -80
  25. docent/data_models/util.py +170 -0
  26. docent/judges/__init__.py +23 -0
  27. docent/judges/analysis.py +77 -0
  28. docent/judges/impl.py +587 -0
  29. docent/judges/runner.py +129 -0
  30. docent/judges/stats.py +205 -0
  31. docent/judges/types.py +311 -0
  32. docent/judges/util/forgiving_json.py +108 -0
  33. docent/judges/util/meta_schema.json +86 -0
  34. docent/judges/util/meta_schema.py +29 -0
  35. docent/judges/util/parse_output.py +87 -0
  36. docent/judges/util/voting.py +139 -0
  37. docent/sdk/agent_run_writer.py +72 -21
  38. docent/sdk/client.py +276 -23
  39. docent/trace.py +413 -90
  40. {docent_python-0.1.14a0.dist-info → docent_python-0.1.28a0.dist-info}/METADATA +13 -5
  41. docent_python-0.1.28a0.dist-info/RECORD +59 -0
  42. docent/data_models/metadata.py +0 -229
  43. docent/data_models/yaml_util.py +0 -12
  44. docent_python-0.1.14a0.dist-info/RECORD +0 -32
  45. {docent_python-0.1.14a0.dist-info → docent_python-0.1.28a0.dist-info}/WHEEL +0 -0
  46. {docent_python-0.1.14a0.dist-info → docent_python-0.1.28a0.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,129 @@
1
+ from typing import Protocol, Sequence, runtime_checkable
2
+
3
+ import anyio
4
+ from tqdm.auto import tqdm
5
+
6
+ from docent._llm_util.llm_svc import BaseLLMService
7
+ from docent._log_util import get_logger
8
+ from docent.data_models.agent_run import AgentRun
9
+ from docent.judges import (
10
+ JudgeResult,
11
+ JudgeResultCompletionCallback,
12
+ Rubric,
13
+ )
14
+ from docent.judges.impl import build_judge
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
+ @runtime_checkable
20
+ class AgentRunResolver(Protocol):
21
+ async def __call__(self) -> AgentRun | None: ...
22
+
23
+
24
+ AgentRunInput = AgentRun | AgentRunResolver
25
+
26
+
27
+ async def _resolve_agent_run(agent_run_input: AgentRunInput) -> AgentRun | None:
28
+ if isinstance(agent_run_input, AgentRun):
29
+ return agent_run_input
30
+ else:
31
+ return await agent_run_input()
32
+
33
+
34
+ async def run_rubric(
35
+ agent_runs: Sequence[AgentRunInput],
36
+ rubric: Rubric,
37
+ llm_svc: BaseLLMService,
38
+ callback: JudgeResultCompletionCallback | None = None,
39
+ *,
40
+ n_rollouts_per_input: int | list[int] = 1,
41
+ show_progress: bool = True,
42
+ ) -> list[JudgeResult | None]:
43
+ if not agent_runs:
44
+ raise ValueError("agent_runs must be a non-empty sequence")
45
+ if rubric.n_rollouts_per_input <= 0:
46
+ raise ValueError("rubric.n_rollouts_per_input must be greater than 0")
47
+
48
+ # Normalize n_rollouts_per_input to a list
49
+ if isinstance(n_rollouts_per_input, int):
50
+ if n_rollouts_per_input < 0:
51
+ raise ValueError("n_rollouts_per_input must be non-negative")
52
+ rollouts_per_run = [n_rollouts_per_input] * len(agent_runs)
53
+ else:
54
+ rollouts_per_run = n_rollouts_per_input
55
+ if len(rollouts_per_run) != len(agent_runs):
56
+ raise ValueError("n_rollouts_per_input list must match agent_runs length")
57
+ if any(n < 0 for n in rollouts_per_run):
58
+ raise ValueError("All values in n_rollouts_per_input must be non-negative")
59
+
60
+ judge = build_judge(rubric, llm_svc)
61
+
62
+ total_rollouts = sum(rollouts_per_run)
63
+ logger.info(
64
+ "Running rubric %s version %s against %d agent runs with %d total rollouts",
65
+ rubric.id,
66
+ rubric.version,
67
+ len(agent_runs),
68
+ total_rollouts,
69
+ )
70
+
71
+ agent_results: list[list[JudgeResult | None]] = [[] for _ in agent_runs]
72
+ progress_bar = tqdm(
73
+ total=total_rollouts,
74
+ desc=f"Rubric {rubric.id}",
75
+ disable=not show_progress,
76
+ )
77
+
78
+ # NOTE(mengk): using a (2 * llm max concurrency) semaphore is a hack to avoid
79
+ # hammering _resolve_agent_run, which makes expensive DB calls, when they aren't going to be
80
+ # immediately processed by the LLMService anyways.
81
+ # TODO(mengk): We should eventually implement a more idiomatic solution to this.
82
+ # It's related to the idea of a global concurrency limiter.
83
+ run_judge_semaphore = anyio.Semaphore(llm_svc.max_concurrency * 2)
84
+
85
+ async def _run_single_judge(index: int, agent_run_input: AgentRunInput):
86
+ async with run_judge_semaphore:
87
+ rollout_results: list[JudgeResult | None] = []
88
+
89
+ if rollouts_per_run[index] == 0:
90
+ agent_results[index] = []
91
+ if callback is not None:
92
+ await callback(index, None)
93
+ return
94
+
95
+ agent_run = await _resolve_agent_run(agent_run_input)
96
+ if agent_run is None:
97
+ if callback is not None:
98
+ await callback(index, None)
99
+ return
100
+
101
+ for _ in range(rollouts_per_run[index]):
102
+ result = await judge(agent_run)
103
+ rollout_results.append(result)
104
+ progress_bar.update()
105
+
106
+ agent_results[index] = rollout_results
107
+
108
+ if callback is not None:
109
+ # Filter out None results for the callback
110
+ valid_results = [r for r in rollout_results if r is not None]
111
+ await callback(index, valid_results if valid_results else None)
112
+
113
+ try:
114
+ async with anyio.create_task_group() as tg:
115
+ for index, agent_run in enumerate(agent_runs):
116
+ tg.start_soon(_run_single_judge, index, agent_run)
117
+ finally:
118
+ progress_bar.close()
119
+
120
+ flattened_results = [result for rollouts in agent_results for result in rollouts]
121
+ successful = sum(result is not None for result in flattened_results)
122
+ logger.info(
123
+ "Finished rubric %s: produced %d/%d judge results",
124
+ rubric.id,
125
+ successful,
126
+ len(flattened_results),
127
+ )
128
+
129
+ return flattened_results
docent/judges/stats.py ADDED
@@ -0,0 +1,205 @@
1
+ from typing import Iterator, List, Tuple
2
+
3
+ from scipy import stats
4
+
5
+ from docent._log_util import get_logger
6
+
7
+ logger = get_logger(__name__)
8
+
9
+
10
+ def print_stats_with_intervals(name: str, mean: float, std: float, confidence_levels: list[float]):
11
+ """Print statistics with confidence intervals at multiple confidence levels.
12
+
13
+ Args:
14
+ name: Name of the statistic
15
+ mean: Mean value
16
+ std: Standard deviation
17
+ confidence_levels: List of confidence levels (e.g., [0.90, 0.95, 0.99])
18
+ """
19
+ intervals_str = ", ".join(
20
+ [
21
+ f"{int(level*100)}% interval: [{mean - stats.norm.ppf((1+level)/2) * std:.4f}, {mean + stats.norm.ppf((1+level)/2) * std:.4f}]" # type: ignore
22
+ for level in confidence_levels
23
+ ]
24
+ )
25
+ print(f"{name} mean: {mean:.4f}, std: {std:.4f}, {intervals_str}")
26
+
27
+
28
+ def _bounded_compositions(total: int, parts: int, bound: int) -> Iterator[Tuple[int, ...]]:
29
+ """
30
+ Yield all tuples (x1,...,x_parts) of nonnegative ints summing to `total`
31
+ with each xk <= bound.
32
+ """
33
+
34
+ # Recursive backtracking with pruning by remaining capacity.
35
+ def rec(k: int, remaining: int, prefix: List[int]) -> Iterator[Tuple[int, ...]]:
36
+ if k == parts:
37
+ if remaining == 0:
38
+ yield tuple(prefix)
39
+ return
40
+ # Max we can put here is min(bound, remaining - min_needed_for_rest)
41
+ # The min needed for the rest is 0; also cannot exceed remaining.
42
+ max_here = min(bound, remaining)
43
+ # Optional pruning: ensure the rest can absorb what's left (always true since min=0)
44
+ for x in range(max_here + 1):
45
+ prefix.append(x)
46
+ yield from rec(k + 1, remaining - x, prefix)
47
+ prefix.pop()
48
+
49
+ yield from rec(0, total, [])
50
+
51
+
52
+ def plurality_vectors(m: int, K: int, i: int) -> Iterator[Tuple[int, ...]]:
53
+ """
54
+ Generate all count vectors n = (n1,...,nm) of nonnegative integers with
55
+ sum(n) = K and STRICT plurality at index i:
56
+ n[i] > n[j] for all j != i.
57
+
58
+ Yields vectors in no particular order.
59
+ """
60
+ if not (0 <= i < m):
61
+ raise ValueError("i must be in [0, m).")
62
+ if m < 2 or K < 1:
63
+ return # nothing to yield in degenerate cases
64
+
65
+ for ni in range(1, K + 1): # at least 1 vote for the winner
66
+ rest_total = K - ni
67
+ cap = ni - 1 # strict plurality: others must be <= ni-1
68
+ # If cap < 0 but rest_total > 0, impossible
69
+ if cap < 0 and rest_total > 0:
70
+ continue
71
+ # Build the other m-1 counts under the cap
72
+ for others in _bounded_compositions(rest_total, m - 1, cap):
73
+ # Stitch back in ni at position i
74
+ vec = list(others[:i]) + [ni] + list(others[i:])
75
+ yield tuple(vec)
76
+
77
+
78
+ def p_mode(n: int, p_v: list[float], idx: int) -> float:
79
+ """Probability that the modal sample of sampling Multinom(n, p_v) is the idxth one."""
80
+ count_vecs = plurality_vectors(len(p_v), n, idx)
81
+ return sum(stats.multinomial.pmf(vec, n, p_v) for vec in count_vecs) # type: ignore
82
+
83
+
84
+ # async def analyze_majority_judge(
85
+ # rubric: Rubric,
86
+ # agent_runs: list[AgentRun],
87
+ # matched_labels: dict[str, dict[str, Any]], # agent_run_id -> gold label obj
88
+ # results_path: Path,
89
+ # samples_per_agent_run: int = 10,
90
+ # maj_k: int = 5, # Does not affect data collection
91
+ # max_llm_concurrency: int = 100,
92
+ # ):
93
+ # # if rubric.n_rollouts_per_input != 1:
94
+ # # raise ValueError("You should use n_rollouts_per_input=1")
95
+
96
+ # if not results_path.exists():
97
+ # logger.info(f"Evaluating rubrics and saving results to {results_path}")
98
+
99
+ # max_conc_per_rubric = min(
100
+ # max_llm_concurrency, len(agent_runs) * rubric.n_rollouts_per_input
101
+ # )
102
+ # max_parallel_rubrics = max(1, max_llm_concurrency // max_conc_per_rubric)
103
+ # logger.info(
104
+ # f"Evaluating {samples_per_agent_run} samples per agent run, {max_conc_per_rubric} concurrent LLM calls per rubric, {max_parallel_rubrics} parallel rubrics"
105
+ # )
106
+
107
+ # await evaluate_multiple_rubrics(
108
+ # rubrics=[rubric] * samples_per_agent_run,
109
+ # agent_runs=agent_runs,
110
+ # llm_svc=SimpleLLMService(),
111
+ # output_path=results_path,
112
+ # max_concurrent_llm_calls_per_rubric=max_conc_per_rubric,
113
+ # max_parallel_rubrics=max_parallel_rubrics,
114
+ # )
115
+ # else:
116
+ # logger.info(f"Found existing results at {results_path}, loading them")
117
+
118
+ # rows = load_rubric_results_from_file(results_path)
119
+
120
+ # # Parse results into a flat dataframe
121
+ # parsed_results: list[dict[str, Any]] = []
122
+ # for row in rows:
123
+ # config_key = row.rubric.model_dump_json(
124
+ # exclude={
125
+ # "rubric_text",
126
+ # "id",
127
+ # "version",
128
+ # "system_prompt_template",
129
+ # "citation_instructions",
130
+ # "output_schema",
131
+ # }
132
+ # )
133
+ # for result in row.judge_results:
134
+ # if result is None:
135
+ # print("found null")
136
+ # continue
137
+ # # raise ValueError("result is None")
138
+
139
+ # gold_label = matched_labels[result.agent_run_id]["label"]
140
+ # judge_label = result.output["label"]
141
+ # md = result.result_metadata or {}
142
+
143
+ # parsed_results.append(
144
+ # {
145
+ # "config_key": config_key,
146
+ # "agent_run_id": result.agent_run_id,
147
+ # "gold_label": gold_label,
148
+ # "judge_label": judge_label,
149
+ # "final_results": md.get("final_results"),
150
+ # "is_correct": gold_label == judge_label,
151
+ # }
152
+ # )
153
+
154
+ # df = pd.DataFrame(parsed_results)
155
+ # potential_labels = ["cheating", "future leakage", "not suspicious"]
156
+
157
+ # def _get_pred_dist(x: pd.Series):
158
+ # x = x.tolist()
159
+ # counts = [x.count(l) for l in potential_labels]
160
+ # assert np.sum(counts) == len(x)
161
+ # return {l: counts[i] / len(x) for i, l in enumerate(potential_labels)}
162
+
163
+ # n_ars = len(df.groupby("agent_run_id").count())
164
+ # p_correct = (
165
+ # df.groupby("agent_run_id")
166
+ # .agg(
167
+ # {
168
+ # "gold_label": lambda x: x.iloc[0],
169
+ # "judge_label": _get_pred_dist,
170
+ # "is_correct": np.mean,
171
+ # }
172
+ # )
173
+ # .rename(columns={"judge_label": "pred_dist", "is_correct": "p_correct_naive"})
174
+ # )
175
+ # p_correct["p_correct_naive_var"] = p_correct["p_correct_naive"].apply(lambda x: x * (1 - x))
176
+
177
+ # p_correct["p_correct_majority"] = p_correct.apply(
178
+ # lambda row: p_mode(
179
+ # maj_k,
180
+ # [row["pred_dist"][l] for l in potential_labels],
181
+ # potential_labels.index(row["gold_label"]),
182
+ # ),
183
+ # axis=1,
184
+ # )
185
+ # p_correct["p_correct_majority_var"] = p_correct["p_correct_majority"].apply(
186
+ # lambda x: x * (1 - x)
187
+ # )
188
+ # p_correct.sort_values(by="p_correct_majority_var", ascending=False, inplace=True)
189
+
190
+ # overall_naive_mean = p_correct["p_correct_naive"].mean()
191
+ # overall_naive_std = np.sqrt(p_correct["p_correct_naive_var"].sum() / n_ars**2)
192
+ # overall_majority_mean = p_correct["p_correct_majority"].mean()
193
+ # overall_majority_std = np.sqrt(p_correct["p_correct_majority_var"].sum() / n_ars**2)
194
+
195
+ # confidence_levels = [0.5, 0.95]
196
+ # print_stats_with_intervals(
197
+ # "Overall naive", overall_naive_mean, overall_naive_std, confidence_levels
198
+ # )
199
+ # print_stats_with_intervals(
200
+ # f"Overall majority (k={maj_k})",
201
+ # overall_majority_mean,
202
+ # overall_majority_std,
203
+ # confidence_levels,
204
+ # )
205
+ # return p_correct
docent/judges/types.py ADDED
@@ -0,0 +1,311 @@
1
+ import enum
2
+ import json
3
+ from string import Formatter
4
+ from typing import Any, Callable, Literal, Protocol
5
+ from uuid import uuid4
6
+
7
+ from pydantic import BaseModel, Field, field_serializer, field_validator
8
+
9
+ from docent._llm_util.providers.preference_types import PUBLIC_PROVIDER_PREFERENCES, ModelOption
10
+ from docent._log_util import get_logger
11
+ from docent.data_models.agent_run import AgentRun
12
+ from docent.data_models.citation import parse_citations
13
+ from docent.data_models.transcript import TEXT_RANGE_CITE_INSTRUCTION
14
+ from docent.judges.util.meta_schema import validate_judge_result_schema
15
+
16
+ logger = get_logger(__name__)
17
+
18
+ DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
19
+ Here is a rubric that we are using to judge transcripts of AI agent runs.
20
+
21
+ Rubric:
22
+ <rubric>
23
+ {rubric}
24
+ </rubric>
25
+
26
+ Agent run:
27
+ <agent_run>
28
+ {agent_run}
29
+ </agent_run>
30
+
31
+ Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step.
32
+
33
+ When you are finished, output your final adjudication, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
34
+
35
+ The JSON object you produce must adhere to the following schema:
36
+ {output_schema}
37
+
38
+ {citation_instructions}
39
+ """.strip()
40
+
41
+ DEFAULT_MULTI_TURN_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
42
+ Here is a rubric that we are using to judge transcripts of AI agent runs.
43
+
44
+ Rubric:
45
+ <rubric>
46
+ {rubric}
47
+ </rubric>
48
+
49
+ Agent run:
50
+ <agent_run>
51
+ {agent_run}
52
+ </agent_run>
53
+
54
+ Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step. You must execute **one step in the decision procedure per assistant message turn**. After each turn, output a complete and detailed recount of all actions you took, and everything you discovered. Then call the `step_finished` tool.
55
+
56
+ When you are finished going through the decision procedure, output your final adjudication, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
57
+
58
+ The JSON object you produce must adhere to the following schema:
59
+ {output_schema}
60
+
61
+ {citation_instructions}
62
+ """.strip()
63
+
64
+ DEFAULT_EXPOSED_REASONING_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
65
+ Here is a rubric that we are using to judge transcripts of AI agent runs.
66
+
67
+ Rubric:
68
+ <rubric>
69
+ {rubric}
70
+ </rubric>
71
+
72
+ Agent run:
73
+ <agent_run>
74
+ {agent_run}
75
+ </agent_run>
76
+
77
+ Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step. You must *fully externalize* your reasoning work by outputting details in the assistant message, surrounded by <reasoning>...</reasoning> tags. The reasoning section can be as messy as you need. You should use *high* reasoning effort.
78
+
79
+ When you are finished, output your final adjudication in the assistant message, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
80
+
81
+ The JSON object you produce must adhere to the following schema:
82
+ {output_schema}
83
+
84
+ {citation_instructions}
85
+ """.strip()
86
+
87
+ DEFAULT_CITATION_INSTRUCTIONS = f"""
88
+ For strings which require citations (according to the `citations: True` property), you must also follow these instructions:
89
+ {TEXT_RANGE_CITE_INSTRUCTION}
90
+ """.strip()
91
+
92
+ DEFAULT_JUDGE_OUTPUT_SCHEMA = {
93
+ "type": "object",
94
+ "properties": {
95
+ "label": {"type": "string", "enum": ["match", "no match"]},
96
+ "explanation": {"type": "string", "citations": True},
97
+ },
98
+ # Require these properties to be present
99
+ "required": ["label", "explanation"],
100
+ # Allow additional properties though, as their presence is not breaking
101
+ }
102
+
103
+ DEFAULT_JUDGE_MODEL = PUBLIC_PROVIDER_PREFERENCES.default_judge_models[0]
104
+
105
+
106
+ class JudgeVariant(str, enum.Enum):
107
+ MAJORITY = "majority"
108
+ MULTI_REFLECT = "multi-reflect"
109
+
110
+
111
+ class Rubric(BaseModel):
112
+ """TODO(mengk): this should really be called JudgeConfig,
113
+ but temporarily keeping this for consistency with docent_core."""
114
+
115
+ class Config:
116
+ frozen = True
117
+
118
+ # Primary key
119
+ id: str = Field(default_factory=lambda: str(uuid4()))
120
+ version: int = 1
121
+
122
+ # What the judge actually does
123
+ rubric_text: str
124
+ n_rollouts_per_input: int = 1
125
+ judge_variant: JudgeVariant = JudgeVariant.MAJORITY
126
+ # TODO(mengk): add this to the database
127
+ # No need right now because multi-turn is still very experimental.
128
+ rollout_type: Literal["single_turn", "multi_turn"] = "single_turn"
129
+
130
+ # Default instructions for the judge
131
+ system_prompt_template: str = DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE
132
+ citation_instructions: str = DEFAULT_CITATION_INSTRUCTIONS
133
+ output_schema: dict[str, Any] = DEFAULT_JUDGE_OUTPUT_SCHEMA
134
+
135
+ # How to run the judge
136
+ judge_model: ModelOption = DEFAULT_JUDGE_MODEL
137
+
138
+ def materialize_system_prompt(self, agent_run: AgentRun) -> str:
139
+ """Construct the full prompt text for rubric evaluation.
140
+
141
+ This is the canonical implementation of prompt construction - use this function
142
+ anywhere you need to construct a rubric evaluation prompt (including cost estimation).
143
+ """
144
+
145
+ output_schema_text = json.dumps(self.output_schema, indent=2)
146
+
147
+ # We've already validated that the system prompt template has these keys
148
+ prompt = self.system_prompt_template.format(
149
+ rubric=self.rubric_text,
150
+ agent_run=agent_run.to_text_new(),
151
+ output_schema=output_schema_text,
152
+ # Only include citation instructions if the schema requests citations
153
+ citation_instructions=(
154
+ self.citation_instructions if _schema_requests_citations(self.output_schema) else ""
155
+ ),
156
+ ).strip()
157
+
158
+ return prompt
159
+
160
+ @field_validator("system_prompt_template")
161
+ @classmethod
162
+ def validate_system_prompt_template(cls, system_prompt_template: str):
163
+ # Extract all field names from the template
164
+ formatter = Formatter()
165
+ field_names = {
166
+ field_name
167
+ for _, field_name, _, _ in formatter.parse(system_prompt_template)
168
+ if field_name is not None
169
+ }
170
+
171
+ # Check for required fields
172
+ required_fields = {"agent_run", "output_schema", "rubric", "citation_instructions"}
173
+ missing_fields = required_fields - field_names
174
+
175
+ if missing_fields:
176
+ raise ValueError(
177
+ f"system_prompt_template must contain the following placeholders: {missing_fields}"
178
+ )
179
+
180
+ return system_prompt_template
181
+
182
+ @field_validator("output_schema")
183
+ @classmethod
184
+ def validate_output_schema(cls, output_schema: dict[str, Any]):
185
+ """
186
+ Raises:
187
+ jsonschema.ValidationError: If the schema is invalid
188
+ jsonschema.SchemaError: If the schema is not a valid 2020-12 schema
189
+ """
190
+ validate_judge_result_schema(output_schema)
191
+ return output_schema
192
+
193
+
194
+ class MultiTurnRubric(Rubric):
195
+ system_prompt_template: str = DEFAULT_MULTI_TURN_JUDGE_SYSTEM_PROMPT_TEMPLATE
196
+ rollout_type: Literal["single_turn", "multi_turn"] = "multi_turn"
197
+
198
+
199
+ class ExposedReasoningRubric(Rubric):
200
+ system_prompt_template: str = DEFAULT_EXPOSED_REASONING_JUDGE_SYSTEM_PROMPT_TEMPLATE
201
+
202
+
203
+ class ResultType(enum.Enum):
204
+ """Enum for the type of result that a judge result can have."""
205
+
206
+ DIRECT_RESULT = "direct_result"
207
+ NEAR_MISS = "near_miss"
208
+
209
+
210
+ class JudgeResult(BaseModel):
211
+ class Config:
212
+ frozen = True
213
+
214
+ id: str = Field(default_factory=lambda: str(uuid4()))
215
+ agent_run_id: str
216
+ rubric_id: str
217
+ rubric_version: int
218
+
219
+ # Outputs
220
+ output: dict[str, Any]
221
+ result_metadata: dict[str, Any] | None = None
222
+ result_type: ResultType
223
+
224
+ # Deprecated
225
+ value: str | None = None
226
+
227
+ @field_serializer("result_type")
228
+ def serialize_result_type(self, result_type: ResultType) -> str:
229
+ return result_type.value
230
+
231
+
232
+ class JudgeResultWithCitations(JudgeResult):
233
+ @classmethod
234
+ def from_judge_result(
235
+ cls, result: JudgeResult, schema: dict[str, Any]
236
+ ) -> "JudgeResultWithCitations":
237
+ """Judge result must be validated against the schema before calling this function!"""
238
+
239
+ def _parse_citation_string(output: str) -> dict[str, Any]:
240
+ text, citations = parse_citations(output)
241
+ return {"text": text, "citations": citations}
242
+
243
+ data = result.model_dump()
244
+ try:
245
+ data["output"] = traverse_schema_and_transform(
246
+ data["output"], schema, _parse_citation_string
247
+ )
248
+ except Exception as e:
249
+ logger.error(f"Failed to parse citations: {e}")
250
+ logger.error(f"Output: {data['output']}")
251
+ data["output"] = {"raw": data["output"]}
252
+ return cls(**data)
253
+
254
+
255
+ class JudgeResultCompletionCallback(Protocol):
256
+ """Called when some batch of judge results is completed.
257
+ Supports batched calls for cases where many results are pre-computed.
258
+ This avoids invoking the callback separately for each datapoint.
259
+ """
260
+
261
+ async def __call__(
262
+ self,
263
+ batch_index: int,
264
+ judge_results: list[JudgeResult] | None,
265
+ ) -> None: ...
266
+
267
+
268
+ def traverse_schema_and_transform(
269
+ output: Any,
270
+ schema: dict[str, Any],
271
+ citation_string_handler: Callable[[str], Any],
272
+ ) -> Any:
273
+ """Recursively traverse output based on schema, applying citation_string_handler to citation strings."""
274
+ if schema.get("type") == "string" and schema.get("citations"): # type: ignore
275
+ return citation_string_handler(output)
276
+ elif schema.get("type") == "object":
277
+ properties: dict[str, Any] = schema.get("properties", {})
278
+ result: dict[str, Any] = {}
279
+ for key in properties:
280
+ if key in output:
281
+ result[key] = traverse_schema_and_transform(
282
+ output[key], properties[key], citation_string_handler
283
+ )
284
+ return result
285
+ elif schema.get("type") == "array":
286
+ item_schema: dict[str, Any] = schema.get("items", {})
287
+ return [
288
+ traverse_schema_and_transform(item, item_schema, citation_string_handler)
289
+ for item in output
290
+ ]
291
+ else:
292
+ return output
293
+
294
+
295
+ def _schema_requests_citations(schema: dict[str, Any]) -> bool:
296
+ """Check if any field in the schema requests citations by having 'citations': 'true'."""
297
+
298
+ def _check_field(field_schema: Any) -> bool:
299
+ if isinstance(field_schema, dict):
300
+ if field_schema.get("citations"): # type: ignore
301
+ return True
302
+ for value in field_schema.values(): # type: ignore
303
+ if isinstance(value, dict) and _check_field(value):
304
+ return True
305
+ elif isinstance(value, list):
306
+ for item in value: # type: ignore
307
+ if isinstance(item, dict) and _check_field(item):
308
+ return True
309
+ return False
310
+
311
+ return _check_field(schema)