docent-python 0.1.19a0__py3-none-any.whl → 0.1.21a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +320 -0
- docent/_llm_util/data_models/simple_svc.py +79 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/model_registry.py +126 -0
- docent/_llm_util/prod_llms.py +454 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/transcript.py +2 -0
- docent/data_models/util.py +170 -0
- docent/judges/__init__.py +21 -0
- docent/judges/impl.py +222 -0
- docent/judges/types.py +240 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +84 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +95 -0
- docent/judges/util/voting.py +84 -0
- docent/sdk/client.py +5 -2
- docent/trace.py +1 -1
- docent/trace_2.py +1842 -0
- {docent_python-0.1.19a0.dist-info → docent_python-0.1.21a0.dist-info}/METADATA +10 -5
- docent_python-0.1.21a0.dist-info/RECORD +58 -0
- docent_python-0.1.19a0.dist-info/RECORD +0 -32
- {docent_python-0.1.19a0.dist-info → docent_python-0.1.21a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.19a0.dist-info → docent_python-0.1.21a0.dist-info}/licenses/LICENSE.md +0 -0
docent/data_models/util.py ADDED

@@ -0,0 +1,170 @@
from __future__ import annotations

from typing import Dict, Iterable, List, TypeVar
from uuid import uuid4

from pydantic import BaseModel

from docent.data_models.agent_run import AgentRun

T = TypeVar("T", bound=BaseModel)


def _deep_copy_model(model: T) -> T:
    """Create a deep copy of a Pydantic v2 model.

    Using `model_copy(deep=True)` ensures nested models are fully copied and
    mutations do not affect the original instance.
    """
    return model.model_copy(deep=True)


def clone_agent_run_with_random_ids(agent_run: AgentRun) -> AgentRun:
    """Clone an `AgentRun`, randomizing all IDs and fixing internal references.

    The following transformations are performed on the cloned instance:
    - Assign a new `AgentRun.id`.
    - Assign new `Transcript.id` values and update any references to them (none today).
    - Assign new `TranscriptGroup.id` values.
    - Update `Transcript.transcript_group_id` to the new group IDs where applicable.
    - Update `TranscriptGroup.agent_run_id` to the new `AgentRun.id`.
    - Update `TranscriptGroup.parent_transcript_group_id` to the new group IDs where applicable.

    Notes:
    - If a `parent_transcript_group_id` or `transcript_group_id` references a group id that
      is not present in the cloned run, the reference is left unchanged (mirrors importer behavior).

    Args:
        agent_run: The source `AgentRun` to clone.

    Returns:
        A new, independent `AgentRun` instance with randomized identifiers and consistent references.
    """
    # Validate source integrity before cloning
    # - No duplicate transcript or group IDs
    # - All transcript.group references exist if set
    # - All group.parent references exist if set
    # - All group.agent_run_id match the source run id
    src_transcript_ids = [str(t.id) for t in agent_run.transcripts]
    if len(src_transcript_ids) != len(set(src_transcript_ids)):
        raise ValueError("Duplicate transcript ids detected in source AgentRun")

    src_group_ids = [str(g.id) for g in agent_run.transcript_groups]
    if len(src_group_ids) != len(set(src_group_ids)):
        raise ValueError("Duplicate transcript group ids detected in source AgentRun")

    src_group_id_set = set(src_group_ids)
    for t in agent_run.transcripts:
        if t.transcript_group_id is not None and str(t.transcript_group_id) not in src_group_id_set:
            raise ValueError(
                f"Transcript {t.id} references missing transcript_group_id {t.transcript_group_id}"
            )

    for g in agent_run.transcript_groups:
        if (
            g.parent_transcript_group_id is not None
            and str(g.parent_transcript_group_id) not in src_group_id_set
        ):
            raise ValueError(
                f"TranscriptGroup {g.id} references missing parent_transcript_group_id {g.parent_transcript_group_id}"
            )
        if str(g.agent_run_id) != str(agent_run.id):
            raise ValueError(
                f"TranscriptGroup {g.id} has agent_run_id {g.agent_run_id} which does not match AgentRun.id {agent_run.id}"
            )

    # Deep copy first so we never mutate the caller's instance
    new_run = _deep_copy_model(agent_run)

    # 1) Randomize AgentRun ID
    new_agent_run_id = str(uuid4())
    old_to_new_transcript_id: Dict[str, str] = {}
    old_to_new_group_id: Dict[str, str] = {}

    # 2) Pre-compute new IDs for transcripts and transcript groups without mutating yet
    for transcript in new_run.transcripts:
        old_to_new_transcript_id[str(transcript.id)] = str(uuid4())

    for group in new_run.transcript_groups:
        old_to_new_group_id[str(group.id)] = str(uuid4())

    # 3) Mutate transcript groups: set new id, set agent_run_id, remap parents
    for group in new_run.transcript_groups:
        old_group_id = str(group.id)

        # Assign new group id
        group.id = old_to_new_group_id.get(old_group_id, str(uuid4()))

        # Ensure group points to the new agent run id
        group.agent_run_id = new_agent_run_id

        # Remap parent id; raise if unknown
        if group.parent_transcript_group_id is not None:
            old_parent_id = str(group.parent_transcript_group_id)
            if old_parent_id not in old_to_new_group_id:
                raise ValueError(
                    f"TranscriptGroup {old_group_id} parent_transcript_group_id {old_parent_id} not found in this AgentRun"
                )
            group.parent_transcript_group_id = old_to_new_group_id[old_parent_id]

    # 4) Mutate transcripts: set new id, remap transcript_group_id
    for transcript in new_run.transcripts:
        old_transcript_id = str(transcript.id)

        # Assign new transcript id
        transcript.id = old_to_new_transcript_id.get(old_transcript_id, str(uuid4()))

        # Remap group reference; raise if unknown
        if transcript.transcript_group_id is not None:
            old_group_id_ref = str(transcript.transcript_group_id)
            if old_group_id_ref not in old_to_new_group_id:
                raise ValueError(
                    f"Transcript {old_transcript_id} references transcript_group_id {old_group_id_ref} not found in this AgentRun"
                )
            transcript.transcript_group_id = old_to_new_group_id[old_group_id_ref]

    # 5) Finally set the new run id
    new_run.id = new_agent_run_id

    # Post-validate integrity on the cloned run
    new_group_ids = [str(g.id) for g in new_run.transcript_groups]
    if len(new_group_ids) != len(set(new_group_ids)):
        raise ValueError("Duplicate transcript group ids detected after cloning")
    new_group_id_set = set(new_group_ids)

    new_transcript_ids = [str(t.id) for t in new_run.transcripts]
    if len(new_transcript_ids) != len(set(new_transcript_ids)):
        raise ValueError("Duplicate transcript ids detected after cloning")

    for t in new_run.transcripts:
        if t.transcript_group_id is not None and str(t.transcript_group_id) not in new_group_id_set:
            raise ValueError(
                f"Transcript {t.id} references missing transcript_group_id {t.transcript_group_id} after cloning"
            )

    for g in new_run.transcript_groups:
        if (
            g.parent_transcript_group_id is not None
            and str(g.parent_transcript_group_id) not in new_group_id_set
        ):
            raise ValueError(
                f"TranscriptGroup {g.id} references missing parent_transcript_group_id {g.parent_transcript_group_id} after cloning"
            )
        if str(g.agent_run_id) != str(new_run.id):
            raise ValueError(
                f"TranscriptGroup {g.id} has agent_run_id {g.agent_run_id} which does not match cloned AgentRun.id {new_run.id}"
            )

    return new_run


def clone_agent_runs_with_random_ids(agent_runs: Iterable[AgentRun]) -> List[AgentRun]:
    """Clone a sequence of `AgentRun` objects with randomized IDs.

    Args:
        agent_runs: Iterable of `AgentRun` instances to clone.

    Returns:
        A list of cloned `AgentRun` instances with fresh IDs and consistent references.
    """
    return [clone_agent_run_with_random_ids(ar) for ar in agent_runs]
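As a quick orientation to the new module, here is a minimal usage sketch. It relies only on the two public helpers defined above and on the `.id` attribute they already read; how the input `AgentRun` is obtained is left out, since its constructor is not part of this diff.

from docent.data_models.agent_run import AgentRun
from docent.data_models.util import (
    clone_agent_run_with_random_ids,
    clone_agent_runs_with_random_ids,
)


def duplicate_run(run: AgentRun) -> AgentRun:
    # The clone gets a fresh AgentRun.id plus fresh transcript/group ids,
    # with internal references remapped; the original run is never mutated.
    clone = clone_agent_run_with_random_ids(run)
    assert clone.id != run.id
    return clone


def duplicate_runs(runs: list[AgentRun]) -> list[AgentRun]:
    # Batched variant: one independent clone per input run.
    return clone_agent_runs_with_random_ids(runs)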
docent/judges/__init__.py ADDED

@@ -0,0 +1,21 @@
from docent.judges.impl import BaseJudge, MajorityVotingJudge, MultiReflectionJudge
from docent.judges.types import (
    JudgeResult,
    JudgeResultCompletionCallback,
    JudgeResultWithCitations,
    ResultType,
    Rubric,
)

__all__ = [
    # Judges
    "MajorityVotingJudge",
    "MultiReflectionJudge",
    "BaseJudge",
    # Types
    "Rubric",
    "JudgeResult",
    "JudgeResultWithCitations",
    "JudgeResultCompletionCallback",
    "ResultType",
]
docent/judges/impl.py ADDED

@@ -0,0 +1,222 @@
import json
from abc import ABC, abstractmethod
from typing import Any

from docent._llm_util.data_models.llm_output import LLMOutput
from docent._llm_util.data_models.simple_svc import BaseLLMService, SimpleLLMService
from docent._log_util import get_logger
from docent.data_models.agent_run import AgentRun
from docent.judges.types import JudgeResult, ResultType, Rubric
from docent.judges.util.parse_output import parse_and_validate_llm_output
from docent.judges.util.voting import find_modal_result, get_agreement_keys

logger = get_logger(__name__)


class BaseJudge(ABC):
    def __init__(self, cfg: Rubric, llm_svc: BaseLLMService):
        self.cfg = cfg
        self.llm_svc = llm_svc

    @abstractmethod
    async def __call__(self, agent_run: AgentRun, *args: Any, **kwargs: Any) -> JudgeResult | None:
        """Returns None if all rollouts failed to produce a valid output."""


class MajorityVotingJudge(BaseJudge):
    """Rolls out the judge multiple times, then uses majority voting to determine the final result."""

    def __init__(
        self,
        cfg: Rubric,
        n_rollouts_per_input: int,
        llm_svc: BaseLLMService = SimpleLLMService(),
    ):
        super().__init__(cfg, llm_svc)
        self.n_rollouts_per_input = n_rollouts_per_input

    async def __call__(
        self,
        agent_run: AgentRun,
        max_concurrency: int = 10,
    ) -> JudgeResult | None:
        async def _validation_callback(batch_index: int, llm_output: LLMOutput):
            parse_and_validate_llm_output(llm_output, self.cfg.output_schema, agent_run)

        prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
        outputs = await self.llm_svc.get_completions(
            inputs=[prompt for _ in range(self.n_rollouts_per_input)],
            model_options=[self.cfg.judge_model],
            max_new_tokens=16384,
            timeout=180.0,
            use_cache=False,
            validation_callback=_validation_callback,
            max_concurrency=max_concurrency,
        )

        # Process each rollout independently
        indep_results: list[dict[str, Any]] = []
        for output in outputs:
            if validated_output := parse_and_validate_llm_output(
                output, self.cfg.output_schema, agent_run
            ):
                indep_results.append(validated_output)

        if not indep_results:
            return None

        # Get a list of the keys that we want to measure agreement on
        agreement_keys = get_agreement_keys(self.cfg.output_schema)

        # Find the result that best matches modal values
        final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
            indep_results, agreement_keys
        )
        final_output = indep_results[final_max_idx]

        return JudgeResult(
            agent_run_id=agent_run.id,
            rubric_id=self.cfg.id,
            rubric_version=self.cfg.version,
            output=final_output,
            result_metadata={
                "agt_keys": agreement_keys,
                # Final measurements
                "final_results": indep_results,
                "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                "final_max_idx": final_max_idx,
            },
            result_type=ResultType.DIRECT_RESULT,
        )


class MultiReflectionJudge(BaseJudge):
    """Rolls out the judge multiple times, then uses reflection to determine the final result."""

    def __init__(
        self,
        cfg: Rubric,
        n_rollouts_per_input: int,
        llm_svc: BaseLLMService = SimpleLLMService(),
    ):
        super().__init__(cfg, llm_svc)
        self.n_rollouts_per_input = n_rollouts_per_input

    async def __call__(
        self,
        agent_run: AgentRun,
        max_concurrency: int = 10,
    ) -> JudgeResult | None:
        rubric = self.cfg

        async def _validation_callback(batch_index: int, llm_output: LLMOutput):
            parse_and_validate_llm_output(llm_output, rubric.output_schema, agent_run)

        # Run several independent rollouts
        prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
        outputs = await self.llm_svc.get_completions(
            inputs=[prompt for _ in range(self.n_rollouts_per_input)],
            model_options=[rubric.judge_model],
            max_new_tokens=16384,
            timeout=180.0,
            use_cache=False,
            validation_callback=_validation_callback,
            max_concurrency=max_concurrency,
        )

        # Process each rollout
        indep_results: list[dict[str, Any]] = []
        for output in outputs:
            if output.first_text is None:
                continue
            if v_output := parse_and_validate_llm_output(output, rubric.output_schema, agent_run):
                indep_results.append(v_output)

        if not indep_results:
            return None

        # Compute initial modes
        agreement_keys = get_agreement_keys(rubric.output_schema)
        indep_max_idx, indep_agt_key_modes_and_counts = find_modal_result(
            indep_results, agreement_keys
        )

        def _get_reflection_prompt(cur_index: int):
            # Current result
            result = indep_results[cur_index]
            # Get other results (excluding the current one)
            other_results = [r for j, r in enumerate(indep_results) if j != cur_index]

            # Create the reflection message
            other_results_text = "\n\n".join(
                [f"Answer {j+1}:\n{json.dumps(r, indent=2)}" for j, r in enumerate(other_results)]
            )

            reflection_instruction = (
                f"Here are {len(other_results)} other independent answers to the same rubric evaluation:\n\n"
                f"{other_results_text}\n\n"
                f"Please reflect on these other answers and your own answer. "
                f"Consider if any of them have identified important aspects you missed, or if there are disagreements that should be resolved. "
                f"Then provide your final answer in the same JSON format as before."
            )

            # Construct the multi-message prompt
            # 1. Original user message
            # 2. Assistant message with the rollout's result
            # 3. New user message asking for reflection
            return [
                *prompt,  # Original user message(s)
                {"role": "assistant", "content": json.dumps(result, indent=2)},
                {"role": "user", "content": reflection_instruction},
            ]

        final_results = indep_results.copy()  # Shallow copy
        if len(indep_results) > 1:
            # Ask the judge to reflect on the others' results
            reflection_outputs = await self.llm_svc.get_completions(
                inputs=[_get_reflection_prompt(i) for i in range(len(indep_results))],
                model_options=[rubric.judge_model],
                max_new_tokens=16384,
                timeout=180.0,
                use_cache=False,
                validation_callback=_validation_callback,
                max_concurrency=max_concurrency,
            )

            # Process reflection outputs in the same way as the initial rollouts
            reflected_results: list[dict[str, Any]] = []
            for output in reflection_outputs:
                if output.first_text is None:
                    continue
                if v_output := parse_and_validate_llm_output(
                    output, rubric.output_schema, agent_run
                ):
                    reflected_results.append(v_output)

            # Use reflected results if we got any, otherwise fall back to original results
            if reflected_results:
                final_results = reflected_results
            else:
                logger.warning("No reflected results found, falling back to original results")

        final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
            final_results, agreement_keys
        )
        return JudgeResult(
            agent_run_id=agent_run.id,
            rubric_id=rubric.id,
            rubric_version=rubric.version,
            output=final_results[final_max_idx],
            result_metadata={
                "agt_keys": agreement_keys,
                # Final measurements
                "final_results": final_results,
                "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
                "final_max_idx": final_max_idx,
                # Also include initial measurements
                "indep_results": indep_results,
                "indep_max_idx": indep_max_idx,
                "indep_agt_key_modes_and_counts": indep_agt_key_modes_and_counts,
            },
            result_type=ResultType.DIRECT_RESULT,
        )
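For context, here is a hedged sketch of how the new judge classes might be driven end to end. It uses only names exported by docent.judges and the call signatures shown above; the rubric text, rollout count, and the way the AgentRun is obtained are placeholders, not part of this diff.

import asyncio

from docent.data_models.agent_run import AgentRun
from docent.judges import MajorityVotingJudge, Rubric


async def judge_run(agent_run: AgentRun) -> None:
    # Rubric only requires rubric_text; the prompt template, output schema,
    # and judge model fall back to the defaults defined in docent/judges/types.py.
    rubric = Rubric(rubric_text="The agent completed the task without fabricating tool output.")
    judge = MajorityVotingJudge(rubric, n_rollouts_per_input=5)

    # Returns None if no rollout produced output that validates against the schema.
    result = await judge(agent_run, max_concurrency=10)
    if result is not None:
        # With the default schema, the voted output carries "label" and "explanation".
        print(result.output["label"])


# asyncio.run(judge_run(some_agent_run))  # `some_agent_run` is a placeholder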
docent/judges/types.py ADDED

@@ -0,0 +1,240 @@
import enum
import json
from string import Formatter
from typing import Any, Callable, Protocol
from uuid import uuid4

from pydantic import BaseModel, Field, field_serializer, field_validator

from docent._llm_util.providers.preference_types import PUBLIC_PROVIDER_PREFERENCES, ModelOption
from docent._log_util import get_logger
from docent.data_models.agent_run import AgentRun
from docent.data_models.citation import parse_citations
from docent.data_models.transcript import TEXT_RANGE_CITE_INSTRUCTION
from docent.judges.util.meta_schema import validate_judge_result_schema

logger = get_logger(__name__)

DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
Here is a rubric that we are using to judge transcripts of AI agent runs.

Rubric:
{rubric}

Agent run:
{agent_run}

Your response should convey your judgment of the agent run according to the criteria given in the rubric provided above. Your entire response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.

The JSON object you produce must adhere to the following schema:
{output_schema}

{citation_instructions}
""".strip()

DEFAULT_CITATION_INSTRUCTIONS = f"""
For strings which require citations (according to the `citations: True` property), you must also follow these instructions:
{TEXT_RANGE_CITE_INSTRUCTION}
""".strip()

DEFAULT_JUDGE_OUTPUT_SCHEMA = {
    "type": "object",
    "properties": {
        "label": {"type": "string", "enum": ["match", "no match"]},
        "explanation": {"type": "string", "citations": True},
    },
    # Require these properties to be present
    "required": ["label", "explanation"],
    # Allow additional properties though, as their presence is not breaking
}

DEFAULT_JUDGE_MODEL = PUBLIC_PROVIDER_PREFERENCES.default_judge_models[0]


class Rubric(BaseModel):
    """TODO(mengk): this should really be called JudgeConfig,
    but temporarily keeping this for consistency with docent_core."""

    class Config:
        frozen = True

    # Primary key
    id: str = Field(default_factory=lambda: str(uuid4()))
    version: int = 1

    # What the judge actually does
    rubric_text: str

    # Default instructions for the judge
    system_prompt_template: str = DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE
    citation_instructions: str = DEFAULT_CITATION_INSTRUCTIONS
    output_schema: dict[str, Any] = DEFAULT_JUDGE_OUTPUT_SCHEMA

    # How to run the judge
    judge_model: ModelOption = DEFAULT_JUDGE_MODEL

    def materialize_system_prompt(self, agent_run: AgentRun) -> str:
        """Construct the full prompt text for rubric evaluation.

        This is the canonical implementation of prompt construction - use this function
        anywhere you need to construct a rubric evaluation prompt (including cost estimation).
        """

        output_schema_text = json.dumps(self.output_schema, indent=2)

        # We've already validated that the system prompt template has these keys
        prompt = self.system_prompt_template.format(
            rubric=self.rubric_text,
            agent_run=agent_run.to_text_new(),
            output_schema=output_schema_text,
            # Only include citation instructions if the schema requests citations
            citation_instructions=(
                self.citation_instructions if _schema_requests_citations(self.output_schema) else ""
            ),
        ).strip()

        return prompt

    @field_validator("system_prompt_template")
    @classmethod
    def validate_system_prompt_template(cls, system_prompt_template: str):
        # Extract all field names from the template
        formatter = Formatter()
        field_names = {
            field_name
            for _, field_name, _, _ in formatter.parse(system_prompt_template)
            if field_name is not None
        }

        # Check for required fields
        required_fields = {"agent_run", "output_schema", "rubric", "citation_instructions"}
        missing_fields = required_fields - field_names

        if missing_fields:
            raise ValueError(
                f"system_prompt_template must contain the following placeholders: {missing_fields}"
            )

        return system_prompt_template

    @field_validator("output_schema")
    @classmethod
    def validate_output_schema(cls, output_schema: dict[str, Any]):
        """
        Raises:
            jsonschema.ValidationError: If the schema is invalid
            jsonschema.SchemaError: If the schema is not a valid 2020-12 schema
        """
        validate_judge_result_schema(output_schema)
        return output_schema


class ResultType(enum.Enum):
    """Enum for the type of result that a judge result can have."""

    DIRECT_RESULT = "direct_result"
    NEAR_MISS = "near_miss"


class JudgeResult(BaseModel):
    class Config:
        frozen = True

    id: str = Field(default_factory=lambda: str(uuid4()))
    agent_run_id: str
    rubric_id: str
    rubric_version: int

    # Outputs
    output: dict[str, Any]
    result_metadata: dict[str, Any] | None = None
    result_type: ResultType

    # Deprecated
    value: str | None = None

    @field_serializer("result_type")
    def serialize_result_type(self, result_type: ResultType) -> str:
        return result_type.value


class JudgeResultWithCitations(JudgeResult):
    @classmethod
    def from_judge_result(
        cls, result: JudgeResult, schema: dict[str, Any]
    ) -> "JudgeResultWithCitations":
        """Judge result must be validated against the schema before calling this function!"""

        def _parse_citation_string(output: str) -> dict[str, Any]:
            text, citations = parse_citations(output)
            return {"text": text, "citations": citations}

        data = result.model_dump()
        try:
            data["output"] = traverse_schema_and_transform(
                data["output"], schema, _parse_citation_string
            )
        except Exception as e:
            logger.error(f"Failed to parse citations: {e}")
            logger.error(f"Output: {data['output']}")
            data["output"] = {"raw": data["output"]}
        return cls(**data)


class JudgeResultCompletionCallback(Protocol):
    """Called when some batch of judge results is completed.
    Supports batched calls for cases where many results are pre-computed.
    This avoids invoking the callback separately for each datapoint.
    """

    async def __call__(
        self,
        batch_index: int,
        judge_results: list[JudgeResult] | None,
    ) -> None: ...


def traverse_schema_and_transform(
    output: Any,
    schema: dict[str, Any],
    citation_string_handler: Callable[[str], Any],
) -> Any:
    """Recursively traverse output based on schema, applying citation_string_handler to citation strings."""
    if schema.get("type") == "string" and schema.get("citations"):  # type: ignore
        return citation_string_handler(output)
    elif schema.get("type") == "object":
        properties: dict[str, Any] = schema.get("properties", {})
        result: dict[str, Any] = {}
        for key in properties:
            if key in output:
                result[key] = traverse_schema_and_transform(
                    output[key], properties[key], citation_string_handler
                )
        return result
    elif schema.get("type") == "array":
        item_schema: dict[str, Any] = schema.get("items", {})
        return [
            traverse_schema_and_transform(item, item_schema, citation_string_handler)
            for item in output
        ]
    else:
        return output


def _schema_requests_citations(schema: dict[str, Any]) -> bool:
    """Check if any field in the schema requests citations by having 'citations': 'true'."""

    def _check_field(field_schema: Any) -> bool:
        if isinstance(field_schema, dict):
            if field_schema.get("citations"):  # type: ignore
                return True
            for value in field_schema.values():  # type: ignore
                if isinstance(value, dict) and _check_field(value):
                    return True
                elif isinstance(value, list):
                    for item in value:  # type: ignore
                        if isinstance(item, dict) and _check_field(item):
                            return True
        return False

    return _check_field(schema)
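To make the schema-driven traversal above concrete, here is a small, self-contained example. The schema, judge output, and handler are invented for illustration; only traverse_schema_and_transform comes from this diff, and JudgeResultWithCitations.from_judge_result plugs in parse_citations where the lambda sits here.

from docent.judges.types import traverse_schema_and_transform

# Illustrative schema: only strings marked "citations": True go through the handler.
schema = {
    "type": "object",
    "properties": {
        "label": {"type": "string", "enum": ["match", "no match"]},
        "explanation": {"type": "string", "citations": True},
        "notes": {"type": "array", "items": {"type": "string"}},
    },
}

# Illustrative judge output that conforms to the schema above.
output = {
    "label": "match",
    "explanation": "The agent edited the test file instead of the source.",
    "notes": ["plain", "strings"],
}

# A stand-in handler that wraps citation-bearing strings.
transformed = traverse_schema_and_transform(
    output, schema, lambda s: {"text": s, "citations": []}
)
# transformed["explanation"] is now {"text": ..., "citations": []};
# "label" and "notes" pass through unchanged.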