docent-python 0.1.21a0__py3-none-any.whl → 0.1.23a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python might be problematic. Click here for more details.
- docent/_llm_util/data_models/llm_output.py +3 -0
- docent/_llm_util/llm_cache.py +4 -4
- docent/_llm_util/{prod_llms.py → llm_svc.py} +104 -86
- docent/_llm_util/providers/preference_types.py +2 -2
- docent/data_models/__init__.py +2 -2
- docent/data_models/judge.py +7 -4
- docent/judges/__init__.py +2 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +484 -119
- docent/judges/runner.py +66 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +73 -2
- docent/judges/util/meta_schema.json +3 -1
- docent/judges/util/parse_output.py +8 -16
- docent/judges/util/voting.py +61 -6
- docent/sdk/client.py +72 -41
- docent/trace.py +18 -0
- {docent_python-0.1.21a0.dist-info → docent_python-0.1.23a0.dist-info}/METADATA +2 -1
- {docent_python-0.1.21a0.dist-info → docent_python-0.1.23a0.dist-info}/RECORD +21 -20
- docent/_llm_util/data_models/simple_svc.py +0 -79
- docent/trace_2.py +0 -1842
- {docent_python-0.1.21a0.dist-info → docent_python-0.1.23a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.21a0.dist-info → docent_python-0.1.23a0.dist-info}/licenses/LICENSE.md +0 -0
docent/judges/runner.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import anyio
|
|
2
|
+
from tqdm.auto import tqdm
|
|
3
|
+
|
|
4
|
+
from docent._llm_util.llm_svc import BaseLLMService
|
|
5
|
+
from docent._log_util import get_logger
|
|
6
|
+
from docent.data_models.agent_run import AgentRun
|
|
7
|
+
from docent.judges import (
|
|
8
|
+
JudgeResult,
|
|
9
|
+
JudgeResultCompletionCallback,
|
|
10
|
+
Rubric,
|
|
11
|
+
)
|
|
12
|
+
from docent.judges.impl import build_judge
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def run_rubric(
    agent_runs: list[AgentRun],
    rubric: Rubric,
    llm_svc: BaseLLMService,
    callback: JudgeResultCompletionCallback | None = None,
    *,
    show_progress: bool = True,
) -> list[JudgeResult | None]:
    """Judge every agent run against ``rubric`` concurrently.

    A single judge is built from ``rubric`` and ``llm_svc`` and awaited once per
    agent run inside one anyio task group.

    Args:
        agent_runs: Runs to judge; must not be empty.
        rubric: Judge configuration; ``n_rollouts_per_input`` must be positive.
        llm_svc: LLM service handed to ``build_judge``.
        callback: Optional coroutine awaited after each run is judged. It
            receives the run's index and either a one-element list holding the
            result or ``None`` when the judge returned ``None``.
        show_progress: When false, the tqdm progress bar is disabled.

    Returns:
        One entry per agent run, in input order; ``None`` where the judge
        returned no result.

    Raises:
        ValueError: If ``agent_runs`` is empty or
            ``rubric.n_rollouts_per_input`` is not greater than 0.
    """
    if not agent_runs:
        raise ValueError("agent_runs must be a non-empty sequence")
    if rubric.n_rollouts_per_input <= 0:
        raise ValueError("rubric.n_rollouts_per_input must be greater than 0")

    judge = build_judge(rubric, llm_svc)

    n_runs = len(agent_runs)
    logger.info(
        "Running rubric %s version %s against %d agent runs", rubric.id, rubric.version, n_runs
    )

    # Slots are filled by index, so results line up with the input order
    # regardless of which task finishes first.
    agent_results: list[JudgeResult | None] = [None] * n_runs
    progress_bar = tqdm(total=n_runs, desc=f"Rubric {rubric.id}", disable=not show_progress)

    async def _run_single_judge(slot: int, run: AgentRun):
        outcome = await judge(run)
        agent_results[slot] = outcome

        if callback is not None:
            payload = None if outcome is None else [outcome]
            await callback(slot, payload)
        progress_bar.update()

    try:
        async with anyio.create_task_group() as tg:
            for slot, run in enumerate(agent_runs):
                tg.start_soon(_run_single_judge, slot, run)
    finally:
        # Close the bar even if a task raised.
        progress_bar.close()

    produced = len([entry for entry in agent_results if entry is not None])
    logger.info(
        "Finished rubric %s: produced %d/%d judge results",
        rubric.id,
        produced,
        len(agent_results),
    )

    return agent_results
|
docent/judges/stats.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
from typing import Iterator, List, Tuple
|
|
2
|
+
|
|
3
|
+
from scipy import stats
|
|
4
|
+
|
|
5
|
+
from docent._log_util import get_logger
|
|
6
|
+
|
|
7
|
+
logger = get_logger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def print_stats_with_intervals(
    name: str, mean: float, std: float, confidence_levels: list[float]
) -> None:
    """Print a mean/std summary followed by normal-approximation intervals.

    For each confidence level ``c`` the printed interval is
    ``[mean - z * std, mean + z * std]`` with ``z = stats.norm.ppf((1 + c) / 2)``.

    Args:
        name: Name of the statistic, used as the line prefix.
        mean: Mean value.
        std: Standard deviation.
        confidence_levels: Confidence levels as fractions (e.g., [0.90, 0.95, 0.99]).
    """
    pieces: list[str] = []
    for level in confidence_levels:
        # Compute the two-sided z-score once per level.
        z = stats.norm.ppf((1 + level) / 2)  # type: ignore
        half_width = z * std
        # ``:g`` avoids the truncation of int(level * 100): 0.58 * 100 is
        # 57.99999999999999 in floating point and used to be printed as "57%".
        pieces.append(
            f"{level * 100:g}% interval: [{mean - half_width:.4f}, {mean + half_width:.4f}]"
        )
    intervals_str = ", ".join(pieces)
    print(f"{name} mean: {mean:.4f}, std: {std:.4f}, {intervals_str}")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _bounded_compositions(total: int, parts: int, bound: int) -> Iterator[Tuple[int, ...]]:
|
|
29
|
+
"""
|
|
30
|
+
Yield all tuples (x1,...,x_parts) of nonnegative ints summing to `total`
|
|
31
|
+
with each xk <= bound.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
# Recursive backtracking with pruning by remaining capacity.
|
|
35
|
+
def rec(k: int, remaining: int, prefix: List[int]) -> Iterator[Tuple[int, ...]]:
|
|
36
|
+
if k == parts:
|
|
37
|
+
if remaining == 0:
|
|
38
|
+
yield tuple(prefix)
|
|
39
|
+
return
|
|
40
|
+
# Max we can put here is min(bound, remaining - min_needed_for_rest)
|
|
41
|
+
# The min needed for the rest is 0; also cannot exceed remaining.
|
|
42
|
+
max_here = min(bound, remaining)
|
|
43
|
+
# Optional pruning: ensure the rest can absorb what's left (always true since min=0)
|
|
44
|
+
for x in range(max_here + 1):
|
|
45
|
+
prefix.append(x)
|
|
46
|
+
yield from rec(k + 1, remaining - x, prefix)
|
|
47
|
+
prefix.pop()
|
|
48
|
+
|
|
49
|
+
yield from rec(0, total, [])
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def plurality_vectors(m: int, K: int, i: int) -> Iterator[Tuple[int, ...]]:
    """
    Generate all count vectors n = (n1,...,nm) of nonnegative integers with
    sum(n) = K and STRICT plurality at index i:
    n[i] > n[j] for all j != i.

    With a single category (m == 1) the condition is vacuous, so the only
    vector, (K,), is yielded. Nothing is yielded when K < 1.

    Because this is a generator, the ValueError for an out-of-range `i` is
    raised on first iteration, not at call time.

    Yields vectors in no particular order.
    """
    if not (0 <= i < m):
        raise ValueError("i must be in [0, m).")
    if K < 1:
        return  # no votes cast: nothing to yield

    # The other m-1 counts are each at most ni-1, so they can absorb at most
    # (m-1)*(ni-1) votes. Feasibility needs ni + (m-1)*(ni-1) >= K, i.e.
    # ni >= ceil((K + m - 1) / m); smaller winner counts cannot yield anything.
    min_winner = -(-(K + m - 1) // m)

    for ni in range(min_winner, K + 1):
        cap = ni - 1  # strict plurality: others must be <= ni-1
        # Build the other m-1 counts under the cap
        for others in _bounded_compositions(K - ni, m - 1, cap):
            # Stitch back in ni at position i
            yield others[:i] + (ni,) + others[i:]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def p_mode(n: int, p_v: list[float], idx: int) -> float:
    """Probability that the modal sample of sampling Multinom(n, p_v) is the idxth one.

    Only strict modes are counted: outcomes where index `idx` ties with another
    index for the highest count contribute nothing.

    Args:
        n: Number of multinomial trials.
        p_v: Category probabilities.
        idx: Index of the category whose probability of being the mode is wanted.

    Returns:
        The summed multinomial pmf over all count vectors with a strict
        plurality at `idx`; 0.0 when there are no such vectors.
    """
    count_vecs = list(plurality_vectors(len(p_v), n, idx))
    if not count_vecs:
        return 0.0
    # One vectorized pmf evaluation over all vectors instead of one scipy call
    # per vector; the result is coerced to a plain Python float.
    return float(stats.multinomial.pmf(count_vecs, n, p_v).sum())  # type: ignore
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# async def analyze_majority_judge(
|
|
85
|
+
# rubric: Rubric,
|
|
86
|
+
# agent_runs: list[AgentRun],
|
|
87
|
+
# matched_labels: dict[str, dict[str, Any]], # agent_run_id -> gold label obj
|
|
88
|
+
# results_path: Path,
|
|
89
|
+
# samples_per_agent_run: int = 10,
|
|
90
|
+
# maj_k: int = 5, # Does not affect data collection
|
|
91
|
+
# max_llm_concurrency: int = 100,
|
|
92
|
+
# ):
|
|
93
|
+
# # if rubric.n_rollouts_per_input != 1:
|
|
94
|
+
# # raise ValueError("You should use n_rollouts_per_input=1")
|
|
95
|
+
|
|
96
|
+
# if not results_path.exists():
|
|
97
|
+
# logger.info(f"Evaluating rubrics and saving results to {results_path}")
|
|
98
|
+
|
|
99
|
+
# max_conc_per_rubric = min(
|
|
100
|
+
# max_llm_concurrency, len(agent_runs) * rubric.n_rollouts_per_input
|
|
101
|
+
# )
|
|
102
|
+
# max_parallel_rubrics = max(1, max_llm_concurrency // max_conc_per_rubric)
|
|
103
|
+
# logger.info(
|
|
104
|
+
# f"Evaluating {samples_per_agent_run} samples per agent run, {max_conc_per_rubric} concurrent LLM calls per rubric, {max_parallel_rubrics} parallel rubrics"
|
|
105
|
+
# )
|
|
106
|
+
|
|
107
|
+
# await evaluate_multiple_rubrics(
|
|
108
|
+
# rubrics=[rubric] * samples_per_agent_run,
|
|
109
|
+
# agent_runs=agent_runs,
|
|
110
|
+
# llm_svc=SimpleLLMService(),
|
|
111
|
+
# output_path=results_path,
|
|
112
|
+
# max_concurrent_llm_calls_per_rubric=max_conc_per_rubric,
|
|
113
|
+
# max_parallel_rubrics=max_parallel_rubrics,
|
|
114
|
+
# )
|
|
115
|
+
# else:
|
|
116
|
+
# logger.info(f"Found existing results at {results_path}, loading them")
|
|
117
|
+
|
|
118
|
+
# rows = load_rubric_results_from_file(results_path)
|
|
119
|
+
|
|
120
|
+
# # Parse results into a flat dataframe
|
|
121
|
+
# parsed_results: list[dict[str, Any]] = []
|
|
122
|
+
# for row in rows:
|
|
123
|
+
# config_key = row.rubric.model_dump_json(
|
|
124
|
+
# exclude={
|
|
125
|
+
# "rubric_text",
|
|
126
|
+
# "id",
|
|
127
|
+
# "version",
|
|
128
|
+
# "system_prompt_template",
|
|
129
|
+
# "citation_instructions",
|
|
130
|
+
# "output_schema",
|
|
131
|
+
# }
|
|
132
|
+
# )
|
|
133
|
+
# for result in row.judge_results:
|
|
134
|
+
# if result is None:
|
|
135
|
+
# print("found null")
|
|
136
|
+
# continue
|
|
137
|
+
# # raise ValueError("result is None")
|
|
138
|
+
|
|
139
|
+
# gold_label = matched_labels[result.agent_run_id]["label"]
|
|
140
|
+
# judge_label = result.output["label"]
|
|
141
|
+
# md = result.result_metadata or {}
|
|
142
|
+
|
|
143
|
+
# parsed_results.append(
|
|
144
|
+
# {
|
|
145
|
+
# "config_key": config_key,
|
|
146
|
+
# "agent_run_id": result.agent_run_id,
|
|
147
|
+
# "gold_label": gold_label,
|
|
148
|
+
# "judge_label": judge_label,
|
|
149
|
+
# "final_results": md.get("final_results"),
|
|
150
|
+
# "is_correct": gold_label == judge_label,
|
|
151
|
+
# }
|
|
152
|
+
# )
|
|
153
|
+
|
|
154
|
+
# df = pd.DataFrame(parsed_results)
|
|
155
|
+
# potential_labels = ["cheating", "future leakage", "not suspicious"]
|
|
156
|
+
|
|
157
|
+
# def _get_pred_dist(x: pd.Series):
|
|
158
|
+
# x = x.tolist()
|
|
159
|
+
# counts = [x.count(l) for l in potential_labels]
|
|
160
|
+
# assert np.sum(counts) == len(x)
|
|
161
|
+
# return {l: counts[i] / len(x) for i, l in enumerate(potential_labels)}
|
|
162
|
+
|
|
163
|
+
# n_ars = len(df.groupby("agent_run_id").count())
|
|
164
|
+
# p_correct = (
|
|
165
|
+
# df.groupby("agent_run_id")
|
|
166
|
+
# .agg(
|
|
167
|
+
# {
|
|
168
|
+
# "gold_label": lambda x: x.iloc[0],
|
|
169
|
+
# "judge_label": _get_pred_dist,
|
|
170
|
+
# "is_correct": np.mean,
|
|
171
|
+
# }
|
|
172
|
+
# )
|
|
173
|
+
# .rename(columns={"judge_label": "pred_dist", "is_correct": "p_correct_naive"})
|
|
174
|
+
# )
|
|
175
|
+
# p_correct["p_correct_naive_var"] = p_correct["p_correct_naive"].apply(lambda x: x * (1 - x))
|
|
176
|
+
|
|
177
|
+
# p_correct["p_correct_majority"] = p_correct.apply(
|
|
178
|
+
# lambda row: p_mode(
|
|
179
|
+
# maj_k,
|
|
180
|
+
# [row["pred_dist"][l] for l in potential_labels],
|
|
181
|
+
# potential_labels.index(row["gold_label"]),
|
|
182
|
+
# ),
|
|
183
|
+
# axis=1,
|
|
184
|
+
# )
|
|
185
|
+
# p_correct["p_correct_majority_var"] = p_correct["p_correct_majority"].apply(
|
|
186
|
+
# lambda x: x * (1 - x)
|
|
187
|
+
# )
|
|
188
|
+
# p_correct.sort_values(by="p_correct_majority_var", ascending=False, inplace=True)
|
|
189
|
+
|
|
190
|
+
# overall_naive_mean = p_correct["p_correct_naive"].mean()
|
|
191
|
+
# overall_naive_std = np.sqrt(p_correct["p_correct_naive_var"].sum() / n_ars**2)
|
|
192
|
+
# overall_majority_mean = p_correct["p_correct_majority"].mean()
|
|
193
|
+
# overall_majority_std = np.sqrt(p_correct["p_correct_majority_var"].sum() / n_ars**2)
|
|
194
|
+
|
|
195
|
+
# confidence_levels = [0.5, 0.95]
|
|
196
|
+
# print_stats_with_intervals(
|
|
197
|
+
# "Overall naive", overall_naive_mean, overall_naive_std, confidence_levels
|
|
198
|
+
# )
|
|
199
|
+
# print_stats_with_intervals(
|
|
200
|
+
# f"Overall majority (k={maj_k})",
|
|
201
|
+
# overall_majority_mean,
|
|
202
|
+
# overall_majority_std,
|
|
203
|
+
# confidence_levels,
|
|
204
|
+
# )
|
|
205
|
+
# return p_correct
|
docent/judges/types.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import enum
|
|
2
2
|
import json
|
|
3
3
|
from string import Formatter
|
|
4
|
-
from typing import Any, Callable, Protocol
|
|
4
|
+
from typing import Any, Callable, Literal, Protocol
|
|
5
5
|
from uuid import uuid4
|
|
6
6
|
|
|
7
7
|
from pydantic import BaseModel, Field, field_serializer, field_validator
|
|
@@ -19,12 +19,64 @@ DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
|
|
|
19
19
|
Here is a rubric that we are using to judge transcripts of AI agent runs.
|
|
20
20
|
|
|
21
21
|
Rubric:
|
|
22
|
+
<rubric>
|
|
22
23
|
{rubric}
|
|
24
|
+
</rubric>
|
|
23
25
|
|
|
24
26
|
Agent run:
|
|
27
|
+
<agent_run>
|
|
25
28
|
{agent_run}
|
|
29
|
+
</agent_run>
|
|
26
30
|
|
|
27
|
-
Your
|
|
31
|
+
Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step.
|
|
32
|
+
|
|
33
|
+
When you are finished, output your final adjudication, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
|
|
34
|
+
|
|
35
|
+
The JSON object you produce must adhere to the following schema:
|
|
36
|
+
{output_schema}
|
|
37
|
+
|
|
38
|
+
{citation_instructions}
|
|
39
|
+
""".strip()
|
|
40
|
+
|
|
41
|
+
DEFAULT_MULTI_TURN_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
|
|
42
|
+
Here is a rubric that we are using to judge transcripts of AI agent runs.
|
|
43
|
+
|
|
44
|
+
Rubric:
|
|
45
|
+
<rubric>
|
|
46
|
+
{rubric}
|
|
47
|
+
</rubric>
|
|
48
|
+
|
|
49
|
+
Agent run:
|
|
50
|
+
<agent_run>
|
|
51
|
+
{agent_run}
|
|
52
|
+
</agent_run>
|
|
53
|
+
|
|
54
|
+
Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step. You must execute **one step in the decision procedure per assistant message turn**. After each turn, output a complete and detailed recount of all actions you took, and everything you discovered. Then call the `step_finished` tool.
|
|
55
|
+
|
|
56
|
+
When you are finished going through the decision procedure, output your final adjudication, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
|
|
57
|
+
|
|
58
|
+
The JSON object you produce must adhere to the following schema:
|
|
59
|
+
{output_schema}
|
|
60
|
+
|
|
61
|
+
{citation_instructions}
|
|
62
|
+
""".strip()
|
|
63
|
+
|
|
64
|
+
DEFAULT_EXPOSED_REASONING_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
|
|
65
|
+
Here is a rubric that we are using to judge transcripts of AI agent runs.
|
|
66
|
+
|
|
67
|
+
Rubric:
|
|
68
|
+
<rubric>
|
|
69
|
+
{rubric}
|
|
70
|
+
</rubric>
|
|
71
|
+
|
|
72
|
+
Agent run:
|
|
73
|
+
<agent_run>
|
|
74
|
+
{agent_run}
|
|
75
|
+
</agent_run>
|
|
76
|
+
|
|
77
|
+
Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step. You must *fully externalize* your reasoning work by outputting details in the assistant message, surrounded by <reasoning>...</reasoning> tags. The reasoning section can be as messy as you need. You should use *high* reasoning effort.
|
|
78
|
+
|
|
79
|
+
When you are finished, output your final adjudication in the assistant message, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
|
|
28
80
|
|
|
29
81
|
The JSON object you produce must adhere to the following schema:
|
|
30
82
|
{output_schema}
|
|
@@ -51,6 +103,11 @@ DEFAULT_JUDGE_OUTPUT_SCHEMA = {
|
|
|
51
103
|
DEFAULT_JUDGE_MODEL = PUBLIC_PROVIDER_PREFERENCES.default_judge_models[0]
|
|
52
104
|
|
|
53
105
|
|
|
106
|
+
class JudgeVariant(str, enum.Enum):
    """String-valued enum of the judge variants a rubric can select."""

    MAJORITY = "majority"
    MULTI_REFLECT = "multi-reflect"
|
|
109
|
+
|
|
110
|
+
|
|
54
111
|
class Rubric(BaseModel):
|
|
55
112
|
"""TODO(mengk): this should really be called JudgeConfig,
|
|
56
113
|
but temporarily keeping this for consistency with docent_core."""
|
|
@@ -64,6 +121,11 @@ class Rubric(BaseModel):
|
|
|
64
121
|
|
|
65
122
|
# What the judge actually does
|
|
66
123
|
rubric_text: str
|
|
124
|
+
n_rollouts_per_input: int = 1
|
|
125
|
+
judge_variant: JudgeVariant = JudgeVariant.MAJORITY
|
|
126
|
+
# TODO(mengk): add this to the database
|
|
127
|
+
# No need right now because multi-turn is still very experimental.
|
|
128
|
+
rollout_type: Literal["single_turn", "multi_turn"] = "single_turn"
|
|
67
129
|
|
|
68
130
|
# Default instructions for the judge
|
|
69
131
|
system_prompt_template: str = DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE
|
|
@@ -129,6 +191,15 @@ class Rubric(BaseModel):
|
|
|
129
191
|
return output_schema
|
|
130
192
|
|
|
131
193
|
|
|
194
|
+
class MultiTurnRubric(Rubric):
    """Rubric whose defaults select the multi-turn prompt template and rollout type."""

    # Overrides the base default with the multi-turn system prompt template.
    system_prompt_template: str = DEFAULT_MULTI_TURN_JUDGE_SYSTEM_PROMPT_TEMPLATE
    # Overrides the base default rollout type.
    rollout_type: Literal["single_turn", "multi_turn"] = "multi_turn"
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class ExposedReasoningRubric(Rubric):
    """Rubric whose default system prompt is the exposed-reasoning template."""

    # Overrides the base default with the exposed-reasoning system prompt template.
    system_prompt_template: str = DEFAULT_EXPOSED_REASONING_JUDGE_SYSTEM_PROMPT_TEMPLATE
|
|
201
|
+
|
|
202
|
+
|
|
132
203
|
class ResultType(enum.Enum):
|
|
133
204
|
"""Enum for the type of result that a judge result can have."""
|
|
134
205
|
|
|
@@ -1,10 +1,8 @@
|
|
|
1
|
-
import json
|
|
2
1
|
from typing import Any, cast
|
|
3
2
|
|
|
4
3
|
import jsonschema
|
|
5
4
|
|
|
6
5
|
from docent._llm_util.data_models.exceptions import ValidationFailedException
|
|
7
|
-
from docent._llm_util.data_models.llm_output import LLMOutput
|
|
8
6
|
from docent._log_util import get_logger
|
|
9
7
|
from docent.data_models.agent_run import AgentRun
|
|
10
8
|
from docent.data_models.remove_invalid_citation_ranges import remove_invalid_citation_ranges
|
|
@@ -55,10 +53,8 @@ def _validate_rubric_output(
|
|
|
55
53
|
)
|
|
56
54
|
|
|
57
55
|
|
|
58
|
-
def
|
|
59
|
-
|
|
60
|
-
output_schema: dict[str, Any],
|
|
61
|
-
agent_run: AgentRun,
|
|
56
|
+
def parse_and_validate_output_str(
|
|
57
|
+
output_str: str, output_schema: dict[str, Any], agent_run: AgentRun
|
|
62
58
|
) -> dict[str, Any]:
|
|
63
59
|
"""Parse and validate LLM output for rubric evaluation.
|
|
64
60
|
|
|
@@ -73,23 +69,19 @@ def parse_and_validate_llm_output(
|
|
|
73
69
|
Raises:
|
|
74
70
|
ValidationFailedException: If parsing or validation fails
|
|
75
71
|
"""
|
|
76
|
-
if llm_output.first_text is None:
|
|
77
|
-
raise ValidationFailedException("LLM output has no text", failed_output=None)
|
|
78
72
|
|
|
79
73
|
try:
|
|
80
|
-
output = forgiving_json_loads(
|
|
81
|
-
except
|
|
74
|
+
output = forgiving_json_loads(output_str)
|
|
75
|
+
except Exception as e:
|
|
82
76
|
raise ValidationFailedException(
|
|
83
|
-
f"Failed to parse JSON: {e}. Raw text: `{
|
|
84
|
-
failed_output=
|
|
77
|
+
f"Failed to parse JSON: {e}. Raw text: `{output_str}`",
|
|
78
|
+
failed_output=output_str,
|
|
85
79
|
)
|
|
86
80
|
|
|
87
81
|
if not isinstance(output, dict):
|
|
88
|
-
logger.error(f"Expected dict output, got {type(output)}")
|
|
89
|
-
logger.error(f"LLM output: {llm_output.first_text}")
|
|
90
82
|
raise ValidationFailedException(
|
|
91
|
-
f"Expected dict output, got {type(output)}. Raw text: {
|
|
92
|
-
failed_output=
|
|
83
|
+
f"Expected dict output, got {type(output)}. Raw text: {output_str}",
|
|
84
|
+
failed_output=output_str,
|
|
93
85
|
)
|
|
94
86
|
|
|
95
87
|
return _validate_rubric_output(cast(dict[str, Any], output), output_schema, agent_run)
|
docent/judges/util/voting.py
CHANGED
|
@@ -1,11 +1,23 @@
|
|
|
1
1
|
from collections import Counter
|
|
2
|
-
from typing import Any, cast
|
|
2
|
+
from typing import Any, TypedDict, cast
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class EstimateWithCI(TypedDict):
    """Point estimate for one output value with its uncertainty.

    Field meanings below are as populated by ``compute_output_distributions``.
    """

    mean: float  # fraction of counted results that took this value
    var: float  # mean * (1 - mean)
    n: int  # number of counted results for the key
    ci_95: float  # 1.96 * sqrt(var / n); 0.0 when n is 0
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
JudgeOutputDistribution = dict[str | bool | int | float, EstimateWithCI]
|
|
3
15
|
|
|
4
16
|
|
|
5
17
|
def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
|
|
6
18
|
"""Get list of top-level keys in schema that we want to measure agreement on.
|
|
7
19
|
|
|
8
|
-
This includes enum
|
|
20
|
+
This includes enum and bool fields.
|
|
9
21
|
|
|
10
22
|
Args:
|
|
11
23
|
schema: JSON schema dict
|
|
@@ -29,10 +41,7 @@ def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
|
|
|
29
41
|
# Include boolean fields
|
|
30
42
|
if field_type == "boolean":
|
|
31
43
|
agreement_keys.append(key)
|
|
32
|
-
# Include
|
|
33
|
-
elif field_type == "integer":
|
|
34
|
-
agreement_keys.append(key)
|
|
35
|
-
# Include enum fields (even strings)
|
|
44
|
+
# Include enum fields (strings and numbers must be in this category)
|
|
36
45
|
elif "enum" in field_schema:
|
|
37
46
|
agreement_keys.append(key)
|
|
38
47
|
|
|
@@ -82,3 +91,49 @@ def find_modal_result(indep_results: list[dict[str, Any]], agreement_keys: list[
|
|
|
82
91
|
max_idx = indep_result_scores.index(max(indep_result_scores))
|
|
83
92
|
|
|
84
93
|
return max_idx, agt_key_modes_and_counts
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def compute_output_distributions(
|
|
97
|
+
indep_results: list[dict[str, Any]], output_schema: dict[str, Any], agreement_keys: list[str]
|
|
98
|
+
):
|
|
99
|
+
def _get_possible_values(key: str) -> list[str | bool | int | float]:
|
|
100
|
+
if "enum" in output_schema.get("properties", {}).get(key, {}):
|
|
101
|
+
return output_schema.get("properties", {}).get(key, {}).get("enum", [])
|
|
102
|
+
elif output_schema.get("properties", {}).get(key, {}).get("type") == "boolean":
|
|
103
|
+
return [True, False]
|
|
104
|
+
else:
|
|
105
|
+
return []
|
|
106
|
+
|
|
107
|
+
raw_counts: dict[str, dict[str | bool | int | float, int]] = {
|
|
108
|
+
key: {value: 0 for value in _get_possible_values(key)} for key in agreement_keys
|
|
109
|
+
}
|
|
110
|
+
# Collect counts for each possible value
|
|
111
|
+
for result in indep_results:
|
|
112
|
+
for key in agreement_keys:
|
|
113
|
+
if (value := result.get(key)) is not None: # Could be none if the key is optional
|
|
114
|
+
assert (
|
|
115
|
+
value in raw_counts[key]
|
|
116
|
+
), "this should never happen; the value must be in possible values, since judge results have been validated against the schema"
|
|
117
|
+
raw_counts[key][value] += 1
|
|
118
|
+
|
|
119
|
+
distributions: dict[str, JudgeOutputDistribution] = {}
|
|
120
|
+
for agt_key in agreement_keys:
|
|
121
|
+
distributions[agt_key] = {}
|
|
122
|
+
|
|
123
|
+
# First normalize the counts to get probabilities
|
|
124
|
+
counts = raw_counts[agt_key]
|
|
125
|
+
total = sum(counts.values())
|
|
126
|
+
probs = {value: (count / total) if total > 0 else 0.0 for value, count in counts.items()}
|
|
127
|
+
|
|
128
|
+
for output_key, value in probs.items():
|
|
129
|
+
mean, estimate_var = value, (value * (1 - value))
|
|
130
|
+
# TODO(mengk): change to the wilson score interval
|
|
131
|
+
ci_95 = float(1.96 * np.sqrt(estimate_var / total)) if total > 0 else 0.0
|
|
132
|
+
distributions[agt_key][output_key] = {
|
|
133
|
+
"mean": mean,
|
|
134
|
+
"var": estimate_var,
|
|
135
|
+
"n": total,
|
|
136
|
+
"ci_95": ci_95,
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
return distributions
|