docent-python 0.1.17a0__py3-none-any.whl → 0.1.27a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +331 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/llm_svc.py +472 -0
- docent/_llm_util/model_registry.py +130 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/__init__.py +2 -0
- docent/data_models/agent_run.py +6 -5
- docent/data_models/chat/__init__.py +6 -1
- docent/data_models/citation.py +103 -22
- docent/data_models/judge.py +19 -0
- docent/data_models/metadata_util.py +16 -0
- docent/data_models/remove_invalid_citation_ranges.py +23 -10
- docent/data_models/transcript.py +20 -16
- docent/data_models/util.py +170 -0
- docent/judges/__init__.py +23 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +587 -0
- docent/judges/runner.py +129 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +311 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +86 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +87 -0
- docent/judges/util/voting.py +139 -0
- docent/sdk/agent_run_writer.py +62 -19
- docent/sdk/client.py +244 -23
- docent/trace.py +413 -90
- {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/METADATA +11 -5
- docent_python-0.1.27a0.dist-info/RECORD +59 -0
- docent/data_models/metadata.py +0 -229
- docent/data_models/yaml_util.py +0 -12
- docent_python-0.1.17a0.dist-info/RECORD +0 -32
- {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/licenses/LICENSE.md +0 -0
docent/judges/runner.py
ADDED
@@ -0,0 +1,129 @@
from typing import Protocol, Sequence, runtime_checkable

import anyio
from tqdm.auto import tqdm

from docent._llm_util.llm_svc import BaseLLMService
from docent._log_util import get_logger
from docent.data_models.agent_run import AgentRun
from docent.judges import (
    JudgeResult,
    JudgeResultCompletionCallback,
    Rubric,
)
from docent.judges.impl import build_judge

logger = get_logger(__name__)


@runtime_checkable
class AgentRunResolver(Protocol):
    async def __call__(self) -> AgentRun | None: ...


AgentRunInput = AgentRun | AgentRunResolver


async def _resolve_agent_run(agent_run_input: AgentRunInput) -> AgentRun | None:
    if isinstance(agent_run_input, AgentRun):
        return agent_run_input
    else:
        return await agent_run_input()


async def run_rubric(
    agent_runs: Sequence[AgentRunInput],
    rubric: Rubric,
    llm_svc: BaseLLMService,
    callback: JudgeResultCompletionCallback | None = None,
    *,
    n_rollouts_per_input: int | list[int] = 1,
    show_progress: bool = True,
) -> list[JudgeResult | None]:
    if not agent_runs:
        raise ValueError("agent_runs must be a non-empty sequence")
    if rubric.n_rollouts_per_input <= 0:
        raise ValueError("rubric.n_rollouts_per_input must be greater than 0")

    # Normalize n_rollouts_per_input to a list
    if isinstance(n_rollouts_per_input, int):
        if n_rollouts_per_input < 0:
            raise ValueError("n_rollouts_per_input must be non-negative")
        rollouts_per_run = [n_rollouts_per_input] * len(agent_runs)
    else:
        rollouts_per_run = n_rollouts_per_input
        if len(rollouts_per_run) != len(agent_runs):
            raise ValueError("n_rollouts_per_input list must match agent_runs length")
        if any(n < 0 for n in rollouts_per_run):
            raise ValueError("All values in n_rollouts_per_input must be non-negative")

    judge = build_judge(rubric, llm_svc)

    total_rollouts = sum(rollouts_per_run)
    logger.info(
        "Running rubric %s version %s against %d agent runs with %d total rollouts",
        rubric.id,
        rubric.version,
        len(agent_runs),
        total_rollouts,
    )

    agent_results: list[list[JudgeResult | None]] = [[] for _ in agent_runs]
    progress_bar = tqdm(
        total=total_rollouts,
        desc=f"Rubric {rubric.id}",
        disable=not show_progress,
    )

    # NOTE(mengk): using a (2 * llm max concurrency) semaphore is a hack to avoid
    # hammering _resolve_agent_run, which makes expensive DB calls, when they aren't going to be
    # immediately processed by the LLMService anyways.
    # TODO(mengk): We should eventually implement a more idiomatic solution to this.
    # It's related to the idea of a global concurrency limiter.
    run_judge_semaphore = anyio.Semaphore(llm_svc.max_concurrency * 2)

    async def _run_single_judge(index: int, agent_run_input: AgentRunInput):
        async with run_judge_semaphore:
            rollout_results: list[JudgeResult | None] = []

            if rollouts_per_run[index] == 0:
                agent_results[index] = []
                if callback is not None:
                    await callback(index, None)
                return

            agent_run = await _resolve_agent_run(agent_run_input)
            if agent_run is None:
                if callback is not None:
                    await callback(index, None)
                return

            for _ in range(rollouts_per_run[index]):
                result = await judge(agent_run)
                rollout_results.append(result)
                progress_bar.update()

            agent_results[index] = rollout_results

            if callback is not None:
                # Filter out None results for the callback
                valid_results = [r for r in rollout_results if r is not None]
                await callback(index, valid_results if valid_results else None)

    try:
        async with anyio.create_task_group() as tg:
            for index, agent_run in enumerate(agent_runs):
                tg.start_soon(_run_single_judge, index, agent_run)
    finally:
        progress_bar.close()

    flattened_results = [result for rollouts in agent_results for result in rollouts]
    successful = sum(result is not None for result in flattened_results)
    logger.info(
        "Finished rubric %s: produced %d/%d judge results",
        rubric.id,
        successful,
        len(flattened_results),
    )

    return flattened_results
docent/judges/stats.py
ADDED
@@ -0,0 +1,205 @@
from typing import Iterator, List, Tuple

from scipy import stats

from docent._log_util import get_logger

logger = get_logger(__name__)


def print_stats_with_intervals(name: str, mean: float, std: float, confidence_levels: list[float]):
    """Print statistics with confidence intervals at multiple confidence levels.

    Args:
        name: Name of the statistic
        mean: Mean value
        std: Standard deviation
        confidence_levels: List of confidence levels (e.g., [0.90, 0.95, 0.99])
    """
    intervals_str = ", ".join(
        [
            f"{int(level*100)}% interval: [{mean - stats.norm.ppf((1+level)/2) * std:.4f}, {mean + stats.norm.ppf((1+level)/2) * std:.4f}]"  # type: ignore
            for level in confidence_levels
        ]
    )
    print(f"{name} mean: {mean:.4f}, std: {std:.4f}, {intervals_str}")


def _bounded_compositions(total: int, parts: int, bound: int) -> Iterator[Tuple[int, ...]]:
    """
    Yield all tuples (x1,...,x_parts) of nonnegative ints summing to `total`
    with each xk <= bound.
    """

    # Recursive backtracking with pruning by remaining capacity.
    def rec(k: int, remaining: int, prefix: List[int]) -> Iterator[Tuple[int, ...]]:
        if k == parts:
            if remaining == 0:
                yield tuple(prefix)
            return
        # Max we can put here is min(bound, remaining - min_needed_for_rest)
        # The min needed for the rest is 0; also cannot exceed remaining.
        max_here = min(bound, remaining)
        # Optional pruning: ensure the rest can absorb what's left (always true since min=0)
        for x in range(max_here + 1):
            prefix.append(x)
            yield from rec(k + 1, remaining - x, prefix)
            prefix.pop()

    yield from rec(0, total, [])


def plurality_vectors(m: int, K: int, i: int) -> Iterator[Tuple[int, ...]]:
    """
    Generate all count vectors n = (n1,...,nm) of nonnegative integers with
    sum(n) = K and STRICT plurality at index i:
        n[i] > n[j] for all j != i.

    Yields vectors in no particular order.
    """
    if not (0 <= i < m):
        raise ValueError("i must be in [0, m).")
    if m < 2 or K < 1:
        return  # nothing to yield in degenerate cases

    for ni in range(1, K + 1):  # at least 1 vote for the winner
        rest_total = K - ni
        cap = ni - 1  # strict plurality: others must be <= ni-1
        # If cap < 0 but rest_total > 0, impossible
        if cap < 0 and rest_total > 0:
            continue
        # Build the other m-1 counts under the cap
        for others in _bounded_compositions(rest_total, m - 1, cap):
            # Stitch back in ni at position i
            vec = list(others[:i]) + [ni] + list(others[i:])
            yield tuple(vec)


def p_mode(n: int, p_v: list[float], idx: int) -> float:
    """Probability that the modal sample of sampling Multinom(n, p_v) is the idxth one."""
    count_vecs = plurality_vectors(len(p_v), n, idx)
    return sum(stats.multinomial.pmf(vec, n, p_v) for vec in count_vecs)  # type: ignore


# async def analyze_majority_judge(
#     rubric: Rubric,
#     agent_runs: list[AgentRun],
#     matched_labels: dict[str, dict[str, Any]],  # agent_run_id -> gold label obj
#     results_path: Path,
#     samples_per_agent_run: int = 10,
#     maj_k: int = 5,  # Does not affect data collection
#     max_llm_concurrency: int = 100,
# ):
#     # if rubric.n_rollouts_per_input != 1:
#     #     raise ValueError("You should use n_rollouts_per_input=1")

#     if not results_path.exists():
#         logger.info(f"Evaluating rubrics and saving results to {results_path}")

#         max_conc_per_rubric = min(
#             max_llm_concurrency, len(agent_runs) * rubric.n_rollouts_per_input
#         )
#         max_parallel_rubrics = max(1, max_llm_concurrency // max_conc_per_rubric)
#         logger.info(
#             f"Evaluating {samples_per_agent_run} samples per agent run, {max_conc_per_rubric} concurrent LLM calls per rubric, {max_parallel_rubrics} parallel rubrics"
#         )

#         await evaluate_multiple_rubrics(
#             rubrics=[rubric] * samples_per_agent_run,
#             agent_runs=agent_runs,
#             llm_svc=SimpleLLMService(),
#             output_path=results_path,
#             max_concurrent_llm_calls_per_rubric=max_conc_per_rubric,
#             max_parallel_rubrics=max_parallel_rubrics,
#         )
#     else:
#         logger.info(f"Found existing results at {results_path}, loading them")

#     rows = load_rubric_results_from_file(results_path)

#     # Parse results into a flat dataframe
#     parsed_results: list[dict[str, Any]] = []
#     for row in rows:
#         config_key = row.rubric.model_dump_json(
#             exclude={
#                 "rubric_text",
#                 "id",
#                 "version",
#                 "system_prompt_template",
#                 "citation_instructions",
#                 "output_schema",
#             }
#         )
#         for result in row.judge_results:
#             if result is None:
#                 print("found null")
#                 continue
#                 # raise ValueError("result is None")

#             gold_label = matched_labels[result.agent_run_id]["label"]
#             judge_label = result.output["label"]
#             md = result.result_metadata or {}

#             parsed_results.append(
#                 {
#                     "config_key": config_key,
#                     "agent_run_id": result.agent_run_id,
#                     "gold_label": gold_label,
#                     "judge_label": judge_label,
#                     "final_results": md.get("final_results"),
#                     "is_correct": gold_label == judge_label,
#                 }
#             )

#     df = pd.DataFrame(parsed_results)
#     potential_labels = ["cheating", "future leakage", "not suspicious"]

#     def _get_pred_dist(x: pd.Series):
#         x = x.tolist()
#         counts = [x.count(l) for l in potential_labels]
#         assert np.sum(counts) == len(x)
#         return {l: counts[i] / len(x) for i, l in enumerate(potential_labels)}

#     n_ars = len(df.groupby("agent_run_id").count())
#     p_correct = (
#         df.groupby("agent_run_id")
#         .agg(
#             {
#                 "gold_label": lambda x: x.iloc[0],
#                 "judge_label": _get_pred_dist,
#                 "is_correct": np.mean,
#             }
#         )
#         .rename(columns={"judge_label": "pred_dist", "is_correct": "p_correct_naive"})
#     )
#     p_correct["p_correct_naive_var"] = p_correct["p_correct_naive"].apply(lambda x: x * (1 - x))

#     p_correct["p_correct_majority"] = p_correct.apply(
#         lambda row: p_mode(
#             maj_k,
#             [row["pred_dist"][l] for l in potential_labels],
#             potential_labels.index(row["gold_label"]),
#         ),
#         axis=1,
#     )
#     p_correct["p_correct_majority_var"] = p_correct["p_correct_majority"].apply(
#         lambda x: x * (1 - x)
#     )
#     p_correct.sort_values(by="p_correct_majority_var", ascending=False, inplace=True)

#     overall_naive_mean = p_correct["p_correct_naive"].mean()
#     overall_naive_std = np.sqrt(p_correct["p_correct_naive_var"].sum() / n_ars**2)
#     overall_majority_mean = p_correct["p_correct_majority"].mean()
#     overall_majority_std = np.sqrt(p_correct["p_correct_majority_var"].sum() / n_ars**2)

#     confidence_levels = [0.5, 0.95]
#     print_stats_with_intervals(
#         "Overall naive", overall_naive_mean, overall_naive_std, confidence_levels
#     )
#     print_stats_with_intervals(
#         f"Overall majority (k={maj_k})",
#         overall_majority_mean,
#         overall_majority_std,
#         confidence_levels,
#     )
#     return p_correct
docent/judges/types.py
ADDED
@@ -0,0 +1,311 @@
import enum
import json
from string import Formatter
from typing import Any, Callable, Literal, Protocol
from uuid import uuid4

from pydantic import BaseModel, Field, field_serializer, field_validator

from docent._llm_util.providers.preference_types import PUBLIC_PROVIDER_PREFERENCES, ModelOption
from docent._log_util import get_logger
from docent.data_models.agent_run import AgentRun
from docent.data_models.citation import parse_citations
from docent.data_models.transcript import TEXT_RANGE_CITE_INSTRUCTION
from docent.judges.util.meta_schema import validate_judge_result_schema

logger = get_logger(__name__)

DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
Here is a rubric that we are using to judge transcripts of AI agent runs.

Rubric:
<rubric>
{rubric}
</rubric>

Agent run:
<agent_run>
{agent_run}
</agent_run>

Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step.

When you are finished, output your final adjudication, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.

The JSON object you produce must adhere to the following schema:
{output_schema}

{citation_instructions}
""".strip()

DEFAULT_MULTI_TURN_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
Here is a rubric that we are using to judge transcripts of AI agent runs.

Rubric:
<rubric>
{rubric}
</rubric>

Agent run:
<agent_run>
{agent_run}
</agent_run>

Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step. You must execute **one step in the decision procedure per assistant message turn**. After each turn, output a complete and detailed recount of all actions you took, and everything you discovered. Then call the `step_finished` tool.

When you are finished going through the decision procedure, output your final adjudication, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.

The JSON object you produce must adhere to the following schema:
{output_schema}

{citation_instructions}
""".strip()

DEFAULT_EXPOSED_REASONING_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
Here is a rubric that we are using to judge transcripts of AI agent runs.

Rubric:
<rubric>
{rubric}
</rubric>

Agent run:
<agent_run>
{agent_run}
</agent_run>

Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step. You must *fully externalize* your reasoning work by outputting details in the assistant message, surrounded by <reasoning>...</reasoning> tags. The reasoning section can be as messy as you need. You should use *high* reasoning effort.

When you are finished, output your final adjudication in the assistant message, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.

The JSON object you produce must adhere to the following schema:
{output_schema}

{citation_instructions}
""".strip()

DEFAULT_CITATION_INSTRUCTIONS = f"""
For strings which require citations (according to the `citations: True` property), you must also follow these instructions:
{TEXT_RANGE_CITE_INSTRUCTION}
""".strip()

DEFAULT_JUDGE_OUTPUT_SCHEMA = {
    "type": "object",
    "properties": {
        "label": {"type": "string", "enum": ["match", "no match"]},
        "explanation": {"type": "string", "citations": True},
    },
    # Require these properties to be present
    "required": ["label", "explanation"],
    # Allow additional properties though, as their presence is not breaking
}

DEFAULT_JUDGE_MODEL = PUBLIC_PROVIDER_PREFERENCES.default_judge_models[0]


class JudgeVariant(str, enum.Enum):
    MAJORITY = "majority"
    MULTI_REFLECT = "multi-reflect"


class Rubric(BaseModel):
    """TODO(mengk): this should really be called JudgeConfig,
    but temporarily keeping this for consistency with docent_core."""

    class Config:
        frozen = True

    # Primary key
    id: str = Field(default_factory=lambda: str(uuid4()))
    version: int = 1

    # What the judge actually does
    rubric_text: str
    n_rollouts_per_input: int = 1
    judge_variant: JudgeVariant = JudgeVariant.MAJORITY
    # TODO(mengk): add this to the database
    # No need right now because multi-turn is still very experimental.
    rollout_type: Literal["single_turn", "multi_turn"] = "single_turn"

    # Default instructions for the judge
    system_prompt_template: str = DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE
    citation_instructions: str = DEFAULT_CITATION_INSTRUCTIONS
    output_schema: dict[str, Any] = DEFAULT_JUDGE_OUTPUT_SCHEMA

    # How to run the judge
    judge_model: ModelOption = DEFAULT_JUDGE_MODEL

    def materialize_system_prompt(self, agent_run: AgentRun) -> str:
        """Construct the full prompt text for rubric evaluation.

        This is the canonical implementation of prompt construction - use this function
        anywhere you need to construct a rubric evaluation prompt (including cost estimation).
        """

        output_schema_text = json.dumps(self.output_schema, indent=2)

        # We've already validated that the system prompt template has these keys
        prompt = self.system_prompt_template.format(
            rubric=self.rubric_text,
            agent_run=agent_run.to_text_new(),
            output_schema=output_schema_text,
            # Only include citation instructions if the schema requests citations
            citation_instructions=(
                self.citation_instructions if _schema_requests_citations(self.output_schema) else ""
            ),
        ).strip()

        return prompt

    @field_validator("system_prompt_template")
    @classmethod
    def validate_system_prompt_template(cls, system_prompt_template: str):
        # Extract all field names from the template
        formatter = Formatter()
        field_names = {
            field_name
            for _, field_name, _, _ in formatter.parse(system_prompt_template)
            if field_name is not None
        }

        # Check for required fields
        required_fields = {"agent_run", "output_schema", "rubric", "citation_instructions"}
        missing_fields = required_fields - field_names

        if missing_fields:
            raise ValueError(
                f"system_prompt_template must contain the following placeholders: {missing_fields}"
            )

        return system_prompt_template

    @field_validator("output_schema")
    @classmethod
    def validate_output_schema(cls, output_schema: dict[str, Any]):
        """
        Raises:
            jsonschema.ValidationError: If the schema is invalid
            jsonschema.SchemaError: If the schema is not a valid 2020-12 schema
        """
        validate_judge_result_schema(output_schema)
        return output_schema


class MultiTurnRubric(Rubric):
    system_prompt_template: str = DEFAULT_MULTI_TURN_JUDGE_SYSTEM_PROMPT_TEMPLATE
    rollout_type: Literal["single_turn", "multi_turn"] = "multi_turn"


class ExposedReasoningRubric(Rubric):
    system_prompt_template: str = DEFAULT_EXPOSED_REASONING_JUDGE_SYSTEM_PROMPT_TEMPLATE


class ResultType(enum.Enum):
    """Enum for the type of result that a judge result can have."""

    DIRECT_RESULT = "direct_result"
    NEAR_MISS = "near_miss"


class JudgeResult(BaseModel):
    class Config:
        frozen = True

    id: str = Field(default_factory=lambda: str(uuid4()))
    agent_run_id: str
    rubric_id: str
    rubric_version: int

    # Outputs
    output: dict[str, Any]
    result_metadata: dict[str, Any] | None = None
    result_type: ResultType

    # Deprecated
    value: str | None = None

    @field_serializer("result_type")
    def serialize_result_type(self, result_type: ResultType) -> str:
        return result_type.value


class JudgeResultWithCitations(JudgeResult):
    @classmethod
    def from_judge_result(
        cls, result: JudgeResult, schema: dict[str, Any]
    ) -> "JudgeResultWithCitations":
        """Judge result must be validated against the schema before calling this function!"""

        def _parse_citation_string(output: str) -> dict[str, Any]:
            text, citations = parse_citations(output)
            return {"text": text, "citations": citations}

        data = result.model_dump()
        try:
            data["output"] = traverse_schema_and_transform(
                data["output"], schema, _parse_citation_string
            )
        except Exception as e:
            logger.error(f"Failed to parse citations: {e}")
            logger.error(f"Output: {data['output']}")
            data["output"] = {"raw": data["output"]}
        return cls(**data)


class JudgeResultCompletionCallback(Protocol):
    """Called when some batch of judge results is completed.
    Supports batched calls for cases where many results are pre-computed.
    This avoids invoking the callback separately for each datapoint.
    """

    async def __call__(
        self,
        batch_index: int,
        judge_results: list[JudgeResult] | None,
    ) -> None: ...


def traverse_schema_and_transform(
    output: Any,
    schema: dict[str, Any],
    citation_string_handler: Callable[[str], Any],
) -> Any:
    """Recursively traverse output based on schema, applying citation_string_handler to citation strings."""
    if schema.get("type") == "string" and schema.get("citations"):  # type: ignore
        return citation_string_handler(output)
    elif schema.get("type") == "object":
        properties: dict[str, Any] = schema.get("properties", {})
        result: dict[str, Any] = {}
        for key in properties:
            if key in output:
                result[key] = traverse_schema_and_transform(
                    output[key], properties[key], citation_string_handler
                )
        return result
    elif schema.get("type") == "array":
        item_schema: dict[str, Any] = schema.get("items", {})
        return [
            traverse_schema_and_transform(item, item_schema, citation_string_handler)
            for item in output
        ]
    else:
        return output


def _schema_requests_citations(schema: dict[str, Any]) -> bool:
    """Check if any field in the schema requests citations by having 'citations': 'true'."""

    def _check_field(field_schema: Any) -> bool:
        if isinstance(field_schema, dict):
            if field_schema.get("citations"):  # type: ignore
                return True
            for value in field_schema.values():  # type: ignore
                if isinstance(value, dict) and _check_field(value):
                    return True
                elif isinstance(value, list):
                    for item in value:  # type: ignore
                        if isinstance(item, dict) and _check_field(item):
                            return True
        return False

    return _check_field(schema)