docent-python 0.1.21a0__py3-none-any.whl → 0.1.23a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic. Click here for more details.

@@ -0,0 +1,66 @@
1
+ import anyio
2
+ from tqdm.auto import tqdm
3
+
4
+ from docent._llm_util.llm_svc import BaseLLMService
5
+ from docent._log_util import get_logger
6
+ from docent.data_models.agent_run import AgentRun
7
+ from docent.judges import (
8
+ JudgeResult,
9
+ JudgeResultCompletionCallback,
10
+ Rubric,
11
+ )
12
+ from docent.judges.impl import build_judge
13
+
14
+ logger = get_logger(__name__)
15
+
16
+
17
+ async def run_rubric(
18
+ agent_runs: list[AgentRun],
19
+ rubric: Rubric,
20
+ llm_svc: BaseLLMService,
21
+ callback: JudgeResultCompletionCallback | None = None,
22
+ *,
23
+ show_progress: bool = True,
24
+ ) -> list[JudgeResult | None]:
25
+ if not agent_runs:
26
+ raise ValueError("agent_runs must be a non-empty sequence")
27
+ if rubric.n_rollouts_per_input <= 0:
28
+ raise ValueError("rubric.n_rollouts_per_input must be greater than 0")
29
+
30
+ judge = build_judge(rubric, llm_svc)
31
+
32
+ logger.info(
33
+ "Running rubric %s version %s against %d agent runs",
34
+ rubric.id,
35
+ rubric.version,
36
+ len(agent_runs),
37
+ )
38
+
39
+ agent_results: list[JudgeResult | None] = [None for _ in agent_runs]
40
+ progress_bar = tqdm(
41
+ total=len(agent_runs), desc=f"Rubric {rubric.id}", disable=not show_progress
42
+ )
43
+
44
+ async def _run_single_judge(index: int, agent_run: AgentRun):
45
+ agent_results[index] = result = await judge(agent_run)
46
+
47
+ if callback is not None:
48
+ await callback(index, [result] if result is not None else None)
49
+ progress_bar.update()
50
+
51
+ try:
52
+ async with anyio.create_task_group() as tg:
53
+ for index, agent_run in enumerate(agent_runs):
54
+ tg.start_soon(_run_single_judge, index, agent_run)
55
+ finally:
56
+ progress_bar.close()
57
+
58
+ successful = sum(result is not None for result in agent_results)
59
+ logger.info(
60
+ "Finished rubric %s: produced %d/%d judge results",
61
+ rubric.id,
62
+ successful,
63
+ len(agent_results),
64
+ )
65
+
66
+ return agent_results
docent/judges/stats.py ADDED
@@ -0,0 +1,205 @@
1
+ from typing import Iterator, List, Tuple
2
+
3
+ from scipy import stats
4
+
5
+ from docent._log_util import get_logger
6
+
7
+ logger = get_logger(__name__)
8
+
9
+
10
+ def print_stats_with_intervals(name: str, mean: float, std: float, confidence_levels: list[float]):
11
+ """Print statistics with confidence intervals at multiple confidence levels.
12
+
13
+ Args:
14
+ name: Name of the statistic
15
+ mean: Mean value
16
+ std: Standard deviation
17
+ confidence_levels: List of confidence levels (e.g., [0.90, 0.95, 0.99])
18
+ """
19
+ intervals_str = ", ".join(
20
+ [
21
+ f"{int(level*100)}% interval: [{mean - stats.norm.ppf((1+level)/2) * std:.4f}, {mean + stats.norm.ppf((1+level)/2) * std:.4f}]" # type: ignore
22
+ for level in confidence_levels
23
+ ]
24
+ )
25
+ print(f"{name} mean: {mean:.4f}, std: {std:.4f}, {intervals_str}")
26
+
27
+
28
+ def _bounded_compositions(total: int, parts: int, bound: int) -> Iterator[Tuple[int, ...]]:
29
+ """
30
+ Yield all tuples (x1,...,x_parts) of nonnegative ints summing to `total`
31
+ with each xk <= bound.
32
+ """
33
+
34
+ # Recursive backtracking with pruning by remaining capacity.
35
+ def rec(k: int, remaining: int, prefix: List[int]) -> Iterator[Tuple[int, ...]]:
36
+ if k == parts:
37
+ if remaining == 0:
38
+ yield tuple(prefix)
39
+ return
40
+ # Max we can put here is min(bound, remaining - min_needed_for_rest)
41
+ # The min needed for the rest is 0; also cannot exceed remaining.
42
+ max_here = min(bound, remaining)
43
+ # Optional pruning: ensure the rest can absorb what's left (always true since min=0)
44
+ for x in range(max_here + 1):
45
+ prefix.append(x)
46
+ yield from rec(k + 1, remaining - x, prefix)
47
+ prefix.pop()
48
+
49
+ yield from rec(0, total, [])
50
+
51
+
52
+ def plurality_vectors(m: int, K: int, i: int) -> Iterator[Tuple[int, ...]]:
53
+ """
54
+ Generate all count vectors n = (n1,...,nm) of nonnegative integers with
55
+ sum(n) = K and STRICT plurality at index i:
56
+ n[i] > n[j] for all j != i.
57
+
58
+ Yields vectors in no particular order.
59
+ """
60
+ if not (0 <= i < m):
61
+ raise ValueError("i must be in [0, m).")
62
+ if m < 2 or K < 1:
63
+ return # nothing to yield in degenerate cases
64
+
65
+ for ni in range(1, K + 1): # at least 1 vote for the winner
66
+ rest_total = K - ni
67
+ cap = ni - 1 # strict plurality: others must be <= ni-1
68
+ # If cap < 0 but rest_total > 0, impossible
69
+ if cap < 0 and rest_total > 0:
70
+ continue
71
+ # Build the other m-1 counts under the cap
72
+ for others in _bounded_compositions(rest_total, m - 1, cap):
73
+ # Stitch back in ni at position i
74
+ vec = list(others[:i]) + [ni] + list(others[i:])
75
+ yield tuple(vec)
76
+
77
+
78
+ def p_mode(n: int, p_v: list[float], idx: int) -> float:
79
+ """Probability that the modal sample of sampling Multinom(n, p_v) is the idxth one."""
80
+ count_vecs = plurality_vectors(len(p_v), n, idx)
81
+ return sum(stats.multinomial.pmf(vec, n, p_v) for vec in count_vecs) # type: ignore
82
+
83
+
84
+ # async def analyze_majority_judge(
85
+ # rubric: Rubric,
86
+ # agent_runs: list[AgentRun],
87
+ # matched_labels: dict[str, dict[str, Any]], # agent_run_id -> gold label obj
88
+ # results_path: Path,
89
+ # samples_per_agent_run: int = 10,
90
+ # maj_k: int = 5, # Does not affect data collection
91
+ # max_llm_concurrency: int = 100,
92
+ # ):
93
+ # # if rubric.n_rollouts_per_input != 1:
94
+ # # raise ValueError("You should use n_rollouts_per_input=1")
95
+
96
+ # if not results_path.exists():
97
+ # logger.info(f"Evaluating rubrics and saving results to {results_path}")
98
+
99
+ # max_conc_per_rubric = min(
100
+ # max_llm_concurrency, len(agent_runs) * rubric.n_rollouts_per_input
101
+ # )
102
+ # max_parallel_rubrics = max(1, max_llm_concurrency // max_conc_per_rubric)
103
+ # logger.info(
104
+ # f"Evaluating {samples_per_agent_run} samples per agent run, {max_conc_per_rubric} concurrent LLM calls per rubric, {max_parallel_rubrics} parallel rubrics"
105
+ # )
106
+
107
+ # await evaluate_multiple_rubrics(
108
+ # rubrics=[rubric] * samples_per_agent_run,
109
+ # agent_runs=agent_runs,
110
+ # llm_svc=SimpleLLMService(),
111
+ # output_path=results_path,
112
+ # max_concurrent_llm_calls_per_rubric=max_conc_per_rubric,
113
+ # max_parallel_rubrics=max_parallel_rubrics,
114
+ # )
115
+ # else:
116
+ # logger.info(f"Found existing results at {results_path}, loading them")
117
+
118
+ # rows = load_rubric_results_from_file(results_path)
119
+
120
+ # # Parse results into a flat dataframe
121
+ # parsed_results: list[dict[str, Any]] = []
122
+ # for row in rows:
123
+ # config_key = row.rubric.model_dump_json(
124
+ # exclude={
125
+ # "rubric_text",
126
+ # "id",
127
+ # "version",
128
+ # "system_prompt_template",
129
+ # "citation_instructions",
130
+ # "output_schema",
131
+ # }
132
+ # )
133
+ # for result in row.judge_results:
134
+ # if result is None:
135
+ # print("found null")
136
+ # continue
137
+ # # raise ValueError("result is None")
138
+
139
+ # gold_label = matched_labels[result.agent_run_id]["label"]
140
+ # judge_label = result.output["label"]
141
+ # md = result.result_metadata or {}
142
+
143
+ # parsed_results.append(
144
+ # {
145
+ # "config_key": config_key,
146
+ # "agent_run_id": result.agent_run_id,
147
+ # "gold_label": gold_label,
148
+ # "judge_label": judge_label,
149
+ # "final_results": md.get("final_results"),
150
+ # "is_correct": gold_label == judge_label,
151
+ # }
152
+ # )
153
+
154
+ # df = pd.DataFrame(parsed_results)
155
+ # potential_labels = ["cheating", "future leakage", "not suspicious"]
156
+
157
+ # def _get_pred_dist(x: pd.Series):
158
+ # x = x.tolist()
159
+ # counts = [x.count(l) for l in potential_labels]
160
+ # assert np.sum(counts) == len(x)
161
+ # return {l: counts[i] / len(x) for i, l in enumerate(potential_labels)}
162
+
163
+ # n_ars = len(df.groupby("agent_run_id").count())
164
+ # p_correct = (
165
+ # df.groupby("agent_run_id")
166
+ # .agg(
167
+ # {
168
+ # "gold_label": lambda x: x.iloc[0],
169
+ # "judge_label": _get_pred_dist,
170
+ # "is_correct": np.mean,
171
+ # }
172
+ # )
173
+ # .rename(columns={"judge_label": "pred_dist", "is_correct": "p_correct_naive"})
174
+ # )
175
+ # p_correct["p_correct_naive_var"] = p_correct["p_correct_naive"].apply(lambda x: x * (1 - x))
176
+
177
+ # p_correct["p_correct_majority"] = p_correct.apply(
178
+ # lambda row: p_mode(
179
+ # maj_k,
180
+ # [row["pred_dist"][l] for l in potential_labels],
181
+ # potential_labels.index(row["gold_label"]),
182
+ # ),
183
+ # axis=1,
184
+ # )
185
+ # p_correct["p_correct_majority_var"] = p_correct["p_correct_majority"].apply(
186
+ # lambda x: x * (1 - x)
187
+ # )
188
+ # p_correct.sort_values(by="p_correct_majority_var", ascending=False, inplace=True)
189
+
190
+ # overall_naive_mean = p_correct["p_correct_naive"].mean()
191
+ # overall_naive_std = np.sqrt(p_correct["p_correct_naive_var"].sum() / n_ars**2)
192
+ # overall_majority_mean = p_correct["p_correct_majority"].mean()
193
+ # overall_majority_std = np.sqrt(p_correct["p_correct_majority_var"].sum() / n_ars**2)
194
+
195
+ # confidence_levels = [0.5, 0.95]
196
+ # print_stats_with_intervals(
197
+ # "Overall naive", overall_naive_mean, overall_naive_std, confidence_levels
198
+ # )
199
+ # print_stats_with_intervals(
200
+ # f"Overall majority (k={maj_k})",
201
+ # overall_majority_mean,
202
+ # overall_majority_std,
203
+ # confidence_levels,
204
+ # )
205
+ # return p_correct
docent/judges/types.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import enum
2
2
  import json
3
3
  from string import Formatter
4
- from typing import Any, Callable, Protocol
4
+ from typing import Any, Callable, Literal, Protocol
5
5
  from uuid import uuid4
6
6
 
7
7
  from pydantic import BaseModel, Field, field_serializer, field_validator
@@ -19,12 +19,64 @@ DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
19
19
  Here is a rubric that we are using to judge transcripts of AI agent runs.
20
20
 
21
21
  Rubric:
22
+ <rubric>
22
23
  {rubric}
24
+ </rubric>
23
25
 
24
26
  Agent run:
27
+ <agent_run>
25
28
  {agent_run}
29
+ </agent_run>
26
30
 
27
- Your response should convey your judgment of the agent run according to the criteria given in the rubric provided above. Your entire response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
31
+ Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step.
32
+
33
+ When you are finished, output your final adjudication, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
34
+
35
+ The JSON object you produce must adhere to the following schema:
36
+ {output_schema}
37
+
38
+ {citation_instructions}
39
+ """.strip()
40
+
41
+ DEFAULT_MULTI_TURN_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
42
+ Here is a rubric that we are using to judge transcripts of AI agent runs.
43
+
44
+ Rubric:
45
+ <rubric>
46
+ {rubric}
47
+ </rubric>
48
+
49
+ Agent run:
50
+ <agent_run>
51
+ {agent_run}
52
+ </agent_run>
53
+
54
+ Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step. You must execute **one step in the decision procedure per assistant message turn**. After each turn, output a complete and detailed recount of all actions you took, and everything you discovered. Then call the `step_finished` tool.
55
+
56
+ When you are finished going through the decision procedure, output your final adjudication, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
57
+
58
+ The JSON object you produce must adhere to the following schema:
59
+ {output_schema}
60
+
61
+ {citation_instructions}
62
+ """.strip()
63
+
64
+ DEFAULT_EXPOSED_REASONING_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
65
+ Here is a rubric that we are using to judge transcripts of AI agent runs.
66
+
67
+ Rubric:
68
+ <rubric>
69
+ {rubric}
70
+ </rubric>
71
+
72
+ Agent run:
73
+ <agent_run>
74
+ {agent_run}
75
+ </agent_run>
76
+
77
+ Your goal is to judge the agent run according to the criteria given in the rubric. Start by faithfully following the decision procedure in extremely careful detail, step by step. You must *fully externalize* your reasoning work by outputting details in the assistant message, surrounded by <reasoning>...</reasoning> tags. The reasoning section can be as messy as you need. You should use *high* reasoning effort.
78
+
79
+ When you are finished, output your final adjudication in the assistant message, surrounded by <response>...</response> tags. The response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
28
80
 
29
81
  The JSON object you produce must adhere to the following schema:
30
82
  {output_schema}
@@ -51,6 +103,11 @@ DEFAULT_JUDGE_OUTPUT_SCHEMA = {
51
103
  DEFAULT_JUDGE_MODEL = PUBLIC_PROVIDER_PREFERENCES.default_judge_models[0]
52
104
 
53
105
 
106
+ class JudgeVariant(str, enum.Enum):
107
+ MAJORITY = "majority"
108
+ MULTI_REFLECT = "multi-reflect"
109
+
110
+
54
111
  class Rubric(BaseModel):
55
112
  """TODO(mengk): this should really be called JudgeConfig,
56
113
  but temporarily keeping this for consistency with docent_core."""
@@ -64,6 +121,11 @@ class Rubric(BaseModel):
64
121
 
65
122
  # What the judge actually does
66
123
  rubric_text: str
124
+ n_rollouts_per_input: int = 1
125
+ judge_variant: JudgeVariant = JudgeVariant.MAJORITY
126
+ # TODO(mengk): add this to the database
127
+ # No need right now because multi-turn is still very experimental.
128
+ rollout_type: Literal["single_turn", "multi_turn"] = "single_turn"
67
129
 
68
130
  # Default instructions for the judge
69
131
  system_prompt_template: str = DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE
@@ -129,6 +191,15 @@ class Rubric(BaseModel):
129
191
  return output_schema
130
192
 
131
193
 
194
+ class MultiTurnRubric(Rubric):
195
+ system_prompt_template: str = DEFAULT_MULTI_TURN_JUDGE_SYSTEM_PROMPT_TEMPLATE
196
+ rollout_type: Literal["single_turn", "multi_turn"] = "multi_turn"
197
+
198
+
199
+ class ExposedReasoningRubric(Rubric):
200
+ system_prompt_template: str = DEFAULT_EXPOSED_REASONING_JUDGE_SYSTEM_PROMPT_TEMPLATE
201
+
202
+
132
203
  class ResultType(enum.Enum):
133
204
  """Enum for the type of result that a judge result can have."""
134
205
 
@@ -33,7 +33,9 @@
33
33
  },
34
34
  "enum": {
35
35
  "type": "array",
36
- "items": { "type": "string" }
36
+ "items": {
37
+ "type": ["string", "integer", "boolean"]
38
+ }
37
39
  },
38
40
  "format": {
39
41
  "type": "string",
@@ -1,10 +1,8 @@
1
- import json
2
1
  from typing import Any, cast
3
2
 
4
3
  import jsonschema
5
4
 
6
5
  from docent._llm_util.data_models.exceptions import ValidationFailedException
7
- from docent._llm_util.data_models.llm_output import LLMOutput
8
6
  from docent._log_util import get_logger
9
7
  from docent.data_models.agent_run import AgentRun
10
8
  from docent.data_models.remove_invalid_citation_ranges import remove_invalid_citation_ranges
@@ -55,10 +53,8 @@ def _validate_rubric_output(
55
53
  )
56
54
 
57
55
 
58
- def parse_and_validate_llm_output(
59
- llm_output: LLMOutput,
60
- output_schema: dict[str, Any],
61
- agent_run: AgentRun,
56
+ def parse_and_validate_output_str(
57
+ output_str: str, output_schema: dict[str, Any], agent_run: AgentRun
62
58
  ) -> dict[str, Any]:
63
59
  """Parse and validate LLM output for rubric evaluation.
64
60
 
@@ -73,23 +69,19 @@ def parse_and_validate_llm_output(
73
69
  Raises:
74
70
  ValidationFailedException: If parsing or validation fails
75
71
  """
76
- if llm_output.first_text is None:
77
- raise ValidationFailedException("LLM output has no text", failed_output=None)
78
72
 
79
73
  try:
80
- output = forgiving_json_loads(llm_output.first_text)
81
- except json.JSONDecodeError as e:
74
+ output = forgiving_json_loads(output_str)
75
+ except Exception as e:
82
76
  raise ValidationFailedException(
83
- f"Failed to parse JSON: {e}. Raw text: `{llm_output.first_text}`",
84
- failed_output=llm_output.first_text,
77
+ f"Failed to parse JSON: {e}. Raw text: `{output_str}`",
78
+ failed_output=output_str,
85
79
  )
86
80
 
87
81
  if not isinstance(output, dict):
88
- logger.error(f"Expected dict output, got {type(output)}")
89
- logger.error(f"LLM output: {llm_output.first_text}")
90
82
  raise ValidationFailedException(
91
- f"Expected dict output, got {type(output)}. Raw text: {llm_output.first_text}",
92
- failed_output=llm_output.first_text,
83
+ f"Expected dict output, got {type(output)}. Raw text: {output_str}",
84
+ failed_output=output_str,
93
85
  )
94
86
 
95
87
  return _validate_rubric_output(cast(dict[str, Any], output), output_schema, agent_run)
@@ -1,11 +1,23 @@
1
1
  from collections import Counter
2
- from typing import Any, cast
2
+ from typing import Any, TypedDict, cast
3
+
4
+ import numpy as np
5
+
6
+
7
+ class EstimateWithCI(TypedDict):
8
+ mean: float
9
+ var: float
10
+ n: int
11
+ ci_95: float
12
+
13
+
14
+ JudgeOutputDistribution = dict[str | bool | int | float, EstimateWithCI]
3
15
 
4
16
 
5
17
  def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
6
18
  """Get list of top-level keys in schema that we want to measure agreement on.
7
19
 
8
- This includes enum, bool, and int fields. We skip float and strings.
20
+ This includes enum and bool fields.
9
21
 
10
22
  Args:
11
23
  schema: JSON schema dict
@@ -29,10 +41,7 @@ def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
29
41
  # Include boolean fields
30
42
  if field_type == "boolean":
31
43
  agreement_keys.append(key)
32
- # Include integer fields
33
- elif field_type == "integer":
34
- agreement_keys.append(key)
35
- # Include enum fields (even strings)
44
+ # Include enum fields (strings and numbers must be in this category)
36
45
  elif "enum" in field_schema:
37
46
  agreement_keys.append(key)
38
47
 
@@ -82,3 +91,49 @@ def find_modal_result(indep_results: list[dict[str, Any]], agreement_keys: list[
82
91
  max_idx = indep_result_scores.index(max(indep_result_scores))
83
92
 
84
93
  return max_idx, agt_key_modes_and_counts
94
+
95
+
96
+ def compute_output_distributions(
97
+ indep_results: list[dict[str, Any]], output_schema: dict[str, Any], agreement_keys: list[str]
98
+ ):
99
+ def _get_possible_values(key: str) -> list[str | bool | int | float]:
100
+ if "enum" in output_schema.get("properties", {}).get(key, {}):
101
+ return output_schema.get("properties", {}).get(key, {}).get("enum", [])
102
+ elif output_schema.get("properties", {}).get(key, {}).get("type") == "boolean":
103
+ return [True, False]
104
+ else:
105
+ return []
106
+
107
+ raw_counts: dict[str, dict[str | bool | int | float, int]] = {
108
+ key: {value: 0 for value in _get_possible_values(key)} for key in agreement_keys
109
+ }
110
+ # Collect counts for each possible value
111
+ for result in indep_results:
112
+ for key in agreement_keys:
113
+ if (value := result.get(key)) is not None: # Could be none if the key is optional
114
+ assert (
115
+ value in raw_counts[key]
116
+ ), "this should never happen; the value must be in possible values, since judge results have been validated against the schema"
117
+ raw_counts[key][value] += 1
118
+
119
+ distributions: dict[str, JudgeOutputDistribution] = {}
120
+ for agt_key in agreement_keys:
121
+ distributions[agt_key] = {}
122
+
123
+ # First normalize the counts to get probabilities
124
+ counts = raw_counts[agt_key]
125
+ total = sum(counts.values())
126
+ probs = {value: (count / total) if total > 0 else 0.0 for value, count in counts.items()}
127
+
128
+ for output_key, value in probs.items():
129
+ mean, estimate_var = value, (value * (1 - value))
130
+ # TODO(mengk): change to the wilson score interval
131
+ ci_95 = float(1.96 * np.sqrt(estimate_var / total)) if total > 0 else 0.0
132
+ distributions[agt_key][output_key] = {
133
+ "mean": mean,
134
+ "var": estimate_var,
135
+ "n": total,
136
+ "ci_95": ci_95,
137
+ }
138
+
139
+ return distributions