docent-python 0.1.19a0__py3-none-any.whl → 0.1.21a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of docent-python might be problematic.

Files changed (34)
  1. docent/_llm_util/__init__.py +0 -0
  2. docent/_llm_util/data_models/__init__.py +0 -0
  3. docent/_llm_util/data_models/exceptions.py +48 -0
  4. docent/_llm_util/data_models/llm_output.py +320 -0
  5. docent/_llm_util/data_models/simple_svc.py +79 -0
  6. docent/_llm_util/llm_cache.py +193 -0
  7. docent/_llm_util/model_registry.py +126 -0
  8. docent/_llm_util/prod_llms.py +454 -0
  9. docent/_llm_util/providers/__init__.py +0 -0
  10. docent/_llm_util/providers/anthropic.py +537 -0
  11. docent/_llm_util/providers/common.py +41 -0
  12. docent/_llm_util/providers/google.py +530 -0
  13. docent/_llm_util/providers/openai.py +745 -0
  14. docent/_llm_util/providers/openrouter.py +375 -0
  15. docent/_llm_util/providers/preference_types.py +104 -0
  16. docent/_llm_util/providers/provider_registry.py +164 -0
  17. docent/data_models/transcript.py +2 -0
  18. docent/data_models/util.py +170 -0
  19. docent/judges/__init__.py +21 -0
  20. docent/judges/impl.py +222 -0
  21. docent/judges/types.py +240 -0
  22. docent/judges/util/forgiving_json.py +108 -0
  23. docent/judges/util/meta_schema.json +84 -0
  24. docent/judges/util/meta_schema.py +29 -0
  25. docent/judges/util/parse_output.py +95 -0
  26. docent/judges/util/voting.py +84 -0
  27. docent/sdk/client.py +5 -2
  28. docent/trace.py +1 -1
  29. docent/trace_2.py +1842 -0
  30. {docent_python-0.1.19a0.dist-info → docent_python-0.1.21a0.dist-info}/METADATA +10 -5
  31. docent_python-0.1.21a0.dist-info/RECORD +58 -0
  32. docent_python-0.1.19a0.dist-info/RECORD +0 -32
  33. {docent_python-0.1.19a0.dist-info → docent_python-0.1.21a0.dist-info}/WHEEL +0 -0
  34. {docent_python-0.1.19a0.dist-info → docent_python-0.1.21a0.dist-info}/licenses/LICENSE.md +0 -0
docent/data_models/util.py ADDED
@@ -0,0 +1,170 @@
+ from __future__ import annotations
+
+ from typing import Dict, Iterable, List, TypeVar
+ from uuid import uuid4
+
+ from pydantic import BaseModel
+
+ from docent.data_models.agent_run import AgentRun
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ def _deep_copy_model(model: T) -> T:
+     """Create a deep copy of a Pydantic v2 model.
+
+     Using `model_copy(deep=True)` ensures nested models are fully copied and
+     mutations do not affect the original instance.
+     """
+     return model.model_copy(deep=True)
+
+
+ def clone_agent_run_with_random_ids(agent_run: AgentRun) -> AgentRun:
+     """Clone an `AgentRun`, randomizing all IDs and fixing internal references.
+
+     The following transformations are performed on the cloned instance:
+     - Assign a new `AgentRun.id`.
+     - Assign new `Transcript.id` values and update any references to them (none today).
+     - Assign new `TranscriptGroup.id` values.
+     - Update `Transcript.transcript_group_id` to the new group IDs where applicable.
+     - Update `TranscriptGroup.agent_run_id` to the new `AgentRun.id`.
+     - Update `TranscriptGroup.parent_transcript_group_id` to the new group IDs where applicable.
+
+     Notes:
+     - If a `parent_transcript_group_id` or `transcript_group_id` references a group id that
+       is not present in the run, a `ValueError` is raised.
+
+     Args:
+         agent_run: The source `AgentRun` to clone.
+
+     Returns:
+         A new, independent `AgentRun` instance with randomized identifiers and consistent references.
+     """
+     # Validate source integrity before cloning
+     # - No duplicate transcript or group IDs
+     # - All transcript.group references exist if set
+     # - All group.parent references exist if set
+     # - All group.agent_run_id match the source run id
+     src_transcript_ids = [str(t.id) for t in agent_run.transcripts]
+     if len(src_transcript_ids) != len(set(src_transcript_ids)):
+         raise ValueError("Duplicate transcript ids detected in source AgentRun")
+
+     src_group_ids = [str(g.id) for g in agent_run.transcript_groups]
+     if len(src_group_ids) != len(set(src_group_ids)):
+         raise ValueError("Duplicate transcript group ids detected in source AgentRun")
+
+     src_group_id_set = set(src_group_ids)
+     for t in agent_run.transcripts:
+         if t.transcript_group_id is not None and str(t.transcript_group_id) not in src_group_id_set:
+             raise ValueError(
+                 f"Transcript {t.id} references missing transcript_group_id {t.transcript_group_id}"
+             )
+
+     for g in agent_run.transcript_groups:
+         if (
+             g.parent_transcript_group_id is not None
+             and str(g.parent_transcript_group_id) not in src_group_id_set
+         ):
+             raise ValueError(
+                 f"TranscriptGroup {g.id} references missing parent_transcript_group_id {g.parent_transcript_group_id}"
+             )
+         if str(g.agent_run_id) != str(agent_run.id):
+             raise ValueError(
+                 f"TranscriptGroup {g.id} has agent_run_id {g.agent_run_id} which does not match AgentRun.id {agent_run.id}"
+             )
+
+     # Deep copy first so we never mutate the caller's instance
+     new_run = _deep_copy_model(agent_run)
+
+     # 1) Randomize AgentRun ID
+     new_agent_run_id = str(uuid4())
+     old_to_new_transcript_id: Dict[str, str] = {}
+     old_to_new_group_id: Dict[str, str] = {}
+
+     # 2) Pre-compute new IDs for transcripts and transcript groups without mutating yet
+     for transcript in new_run.transcripts:
+         old_to_new_transcript_id[str(transcript.id)] = str(uuid4())
+
+     for group in new_run.transcript_groups:
+         old_to_new_group_id[str(group.id)] = str(uuid4())
+
+     # 3) Mutate transcript groups: set new id, set agent_run_id, remap parents
+     for group in new_run.transcript_groups:
+         old_group_id = str(group.id)
+
+         # Assign new group id
+         group.id = old_to_new_group_id.get(old_group_id, str(uuid4()))
+
+         # Ensure group points to the new agent run id
+         group.agent_run_id = new_agent_run_id
+
+         # Remap parent id; raise if unknown
+         if group.parent_transcript_group_id is not None:
+             old_parent_id = str(group.parent_transcript_group_id)
+             if old_parent_id not in old_to_new_group_id:
+                 raise ValueError(
+                     f"TranscriptGroup {old_group_id} parent_transcript_group_id {old_parent_id} not found in this AgentRun"
+                 )
+             group.parent_transcript_group_id = old_to_new_group_id[old_parent_id]
+
+     # 4) Mutate transcripts: set new id, remap transcript_group_id
+     for transcript in new_run.transcripts:
+         old_transcript_id = str(transcript.id)
+
+         # Assign new transcript id
+         transcript.id = old_to_new_transcript_id.get(old_transcript_id, str(uuid4()))
+
+         # Remap group reference; raise if unknown
+         if transcript.transcript_group_id is not None:
+             old_group_id_ref = str(transcript.transcript_group_id)
+             if old_group_id_ref not in old_to_new_group_id:
+                 raise ValueError(
+                     f"Transcript {old_transcript_id} references transcript_group_id {old_group_id_ref} not found in this AgentRun"
+                 )
+             transcript.transcript_group_id = old_to_new_group_id[old_group_id_ref]
+
+     # 5) Finally set the new run id
+     new_run.id = new_agent_run_id
+
+     # Post-validate integrity on the cloned run
+     new_group_ids = [str(g.id) for g in new_run.transcript_groups]
+     if len(new_group_ids) != len(set(new_group_ids)):
+         raise ValueError("Duplicate transcript group ids detected after cloning")
+     new_group_id_set = set(new_group_ids)
+
+     new_transcript_ids = [str(t.id) for t in new_run.transcripts]
+     if len(new_transcript_ids) != len(set(new_transcript_ids)):
+         raise ValueError("Duplicate transcript ids detected after cloning")
+
+     for t in new_run.transcripts:
+         if t.transcript_group_id is not None and str(t.transcript_group_id) not in new_group_id_set:
+             raise ValueError(
+                 f"Transcript {t.id} references missing transcript_group_id {t.transcript_group_id} after cloning"
+             )
+
+     for g in new_run.transcript_groups:
+         if (
+             g.parent_transcript_group_id is not None
+             and str(g.parent_transcript_group_id) not in new_group_id_set
+         ):
+             raise ValueError(
+                 f"TranscriptGroup {g.id} references missing parent_transcript_group_id {g.parent_transcript_group_id} after cloning"
+             )
+         if str(g.agent_run_id) != str(new_run.id):
+             raise ValueError(
+                 f"TranscriptGroup {g.id} has agent_run_id {g.agent_run_id} which does not match cloned AgentRun.id {new_run.id}"
+             )
+
+     return new_run
+
+
+ def clone_agent_runs_with_random_ids(agent_runs: Iterable[AgentRun]) -> List[AgentRun]:
+     """Clone a sequence of `AgentRun` objects with randomized IDs.
+
+     Args:
+         agent_runs: Iterable of `AgentRun` instances to clone.
+
+     Returns:
+         A list of cloned `AgentRun` instances with fresh IDs and consistent references.
+     """
+     return [clone_agent_run_with_random_ids(ar) for ar in agent_runs]
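
For orientation, a minimal usage sketch of the new helper (hypothetical: `source_run` stands in for an already-constructed `AgentRun`; only the names defined in this file are real):

    from docent.data_models.util import clone_agent_run_with_random_ids

    # `source_run` is a placeholder for an existing AgentRun instance.
    cloned = clone_agent_run_with_random_ids(source_run)

    # Same content, fresh UUIDs; the original run is never mutated.
    assert cloned.id != source_run.id
    new_group_ids = {str(g.id) for g in cloned.transcript_groups}
    for t in cloned.transcripts:
        # Remapped group references point at the clone's own groups.
        assert t.transcript_group_id is None or str(t.transcript_group_id) in new_group_ids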
docent/judges/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from docent.judges.impl import BaseJudge, MajorityVotingJudge, MultiReflectionJudge
+ from docent.judges.types import (
+     JudgeResult,
+     JudgeResultCompletionCallback,
+     JudgeResultWithCitations,
+     ResultType,
+     Rubric,
+ )
+
+ __all__ = [
+     # Judges
+     "MajorityVotingJudge",
+     "MultiReflectionJudge",
+     "BaseJudge",
+     # Types
+     "Rubric",
+     "JudgeResult",
+     "JudgeResultWithCitations",
+     "JudgeResultCompletionCallback",
+     "ResultType",
+ ]
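
Because `__init__.py` re-exports these names, `docent.judges` is the intended public import surface. A small sketch (only `rubric_text` is required; the judge model, prompt template, and output schema fall back to the defaults defined in docent/judges/types.py below):

    from docent.judges import Rubric

    # Hypothetical rubric text; all other fields use package defaults.
    rubric = Rubric(rubric_text="The agent completed the task it was given.")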
docent/judges/impl.py ADDED
@@ -0,0 +1,222 @@
+ import json
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+ from docent._llm_util.data_models.llm_output import LLMOutput
+ from docent._llm_util.data_models.simple_svc import BaseLLMService, SimpleLLMService
+ from docent._log_util import get_logger
+ from docent.data_models.agent_run import AgentRun
+ from docent.judges.types import JudgeResult, ResultType, Rubric
+ from docent.judges.util.parse_output import parse_and_validate_llm_output
+ from docent.judges.util.voting import find_modal_result, get_agreement_keys
+
+ logger = get_logger(__name__)
+
+
+ class BaseJudge(ABC):
+     def __init__(self, cfg: Rubric, llm_svc: BaseLLMService):
+         self.cfg = cfg
+         self.llm_svc = llm_svc
+
+     @abstractmethod
+     async def __call__(self, agent_run: AgentRun, *args: Any, **kwargs: Any) -> JudgeResult | None:
+         """Returns None if all rollouts failed to produce a valid output."""
+
+
+ class MajorityVotingJudge(BaseJudge):
+     """Rolls out the judge multiple times, then uses majority voting to determine the final result."""
+
+     def __init__(
+         self,
+         cfg: Rubric,
+         n_rollouts_per_input: int,
+         llm_svc: BaseLLMService = SimpleLLMService(),
+     ):
+         super().__init__(cfg, llm_svc)
+         self.n_rollouts_per_input = n_rollouts_per_input
+
+     async def __call__(
+         self,
+         agent_run: AgentRun,
+         max_concurrency: int = 10,
+     ) -> JudgeResult | None:
+         async def _validation_callback(batch_index: int, llm_output: LLMOutput):
+             parse_and_validate_llm_output(llm_output, self.cfg.output_schema, agent_run)
+
+         prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
+         outputs = await self.llm_svc.get_completions(
+             inputs=[prompt for _ in range(self.n_rollouts_per_input)],
+             model_options=[self.cfg.judge_model],
+             max_new_tokens=16384,
+             timeout=180.0,
+             use_cache=False,
+             validation_callback=_validation_callback,
+             max_concurrency=max_concurrency,
+         )
+
+         # Process each rollout independently
+         indep_results: list[dict[str, Any]] = []
+         for output in outputs:
+             if validated_output := parse_and_validate_llm_output(
+                 output, self.cfg.output_schema, agent_run
+             ):
+                 indep_results.append(validated_output)
+
+         if not indep_results:
+             return None
+
+         # Get a list of the keys that we want to measure agreement on
+         agreement_keys = get_agreement_keys(self.cfg.output_schema)
+
+         # Find the result that best matches modal values
+         final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
+             indep_results, agreement_keys
+         )
+         final_output = indep_results[final_max_idx]
+
+         return JudgeResult(
+             agent_run_id=agent_run.id,
+             rubric_id=self.cfg.id,
+             rubric_version=self.cfg.version,
+             output=final_output,
+             result_metadata={
+                 "agt_keys": agreement_keys,
+                 # Final measurements
+                 "final_results": indep_results,
+                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
+                 "final_max_idx": final_max_idx,
+             },
+             result_type=ResultType.DIRECT_RESULT,
+         )
+
+
+ class MultiReflectionJudge(BaseJudge):
+     """Rolls out the judge multiple times, then uses reflection to determine the final result."""
+
+     def __init__(
+         self,
+         cfg: Rubric,
+         n_rollouts_per_input: int,
+         llm_svc: BaseLLMService = SimpleLLMService(),
+     ):
+         super().__init__(cfg, llm_svc)
+         self.n_rollouts_per_input = n_rollouts_per_input
+
+     async def __call__(
+         self,
+         agent_run: AgentRun,
+         max_concurrency: int = 10,
+     ) -> JudgeResult | None:
+         rubric = self.cfg
+
+         async def _validation_callback(batch_index: int, llm_output: LLMOutput):
+             parse_and_validate_llm_output(llm_output, rubric.output_schema, agent_run)
+
+         # Run several independent rollouts
+         prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
+         outputs = await self.llm_svc.get_completions(
+             inputs=[prompt for _ in range(self.n_rollouts_per_input)],
+             model_options=[rubric.judge_model],
+             max_new_tokens=16384,
+             timeout=180.0,
+             use_cache=False,
+             validation_callback=_validation_callback,
+             max_concurrency=max_concurrency,
+         )
+
+         # Process each rollout
+         indep_results: list[dict[str, Any]] = []
+         for output in outputs:
+             if output.first_text is None:
+                 continue
+             if v_output := parse_and_validate_llm_output(output, rubric.output_schema, agent_run):
+                 indep_results.append(v_output)
+
+         if not indep_results:
+             return None
+
+         # Compute initial modes
+         agreement_keys = get_agreement_keys(rubric.output_schema)
+         indep_max_idx, indep_agt_key_modes_and_counts = find_modal_result(
+             indep_results, agreement_keys
+         )
+
+         def _get_reflection_prompt(cur_index: int):
+             # Current result
+             result = indep_results[cur_index]
+             # Get other results (excluding the current one)
+             other_results = [r for j, r in enumerate(indep_results) if j != cur_index]
+
+             # Create the reflection message
+             other_results_text = "\n\n".join(
+                 [f"Answer {j+1}:\n{json.dumps(r, indent=2)}" for j, r in enumerate(other_results)]
+             )
+
+             reflection_instruction = (
+                 f"Here are {len(other_results)} other independent answers to the same rubric evaluation:\n\n"
+                 f"{other_results_text}\n\n"
+                 f"Please reflect on these other answers and your own answer. "
+                 f"Consider if any of them have identified important aspects you missed, or if there are disagreements that should be resolved. "
+                 f"Then provide your final answer in the same JSON format as before."
+             )
+
+             # Construct the multi-message prompt
+             # 1. Original user message
+             # 2. Assistant message with the rollout's result
+             # 3. New user message asking for reflection
+             return [
+                 *prompt,  # Original user message(s)
+                 {"role": "assistant", "content": json.dumps(result, indent=2)},
+                 {"role": "user", "content": reflection_instruction},
+             ]
+
+         final_results = indep_results.copy()  # Shallow copy
+         if len(indep_results) > 1:
+             # Ask the judge to reflect on the others' results
+             reflection_outputs = await self.llm_svc.get_completions(
+                 inputs=[_get_reflection_prompt(i) for i in range(len(indep_results))],
+                 model_options=[rubric.judge_model],
+                 max_new_tokens=16384,
+                 timeout=180.0,
+                 use_cache=False,
+                 validation_callback=_validation_callback,
+                 max_concurrency=max_concurrency,
+             )
+
+             # Process reflection outputs in the same way as the initial rollouts
+             reflected_results: list[dict[str, Any]] = []
+             for output in reflection_outputs:
+                 if output.first_text is None:
+                     continue
+                 if v_output := parse_and_validate_llm_output(
+                     output, rubric.output_schema, agent_run
+                 ):
+                     reflected_results.append(v_output)
+
+             # Use reflected results if we got any, otherwise fall back to original results
+             if reflected_results:
+                 final_results = reflected_results
+             else:
+                 logger.warning("No reflected results found, falling back to original results")
+
+         final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
+             final_results, agreement_keys
+         )
+         return JudgeResult(
+             agent_run_id=agent_run.id,
+             rubric_id=rubric.id,
+             rubric_version=rubric.version,
+             output=final_results[final_max_idx],
+             result_metadata={
+                 "agt_keys": agreement_keys,
+                 # Final measurements
+                 "final_results": final_results,
+                 "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
+                 "final_max_idx": final_max_idx,
+                 # Also include initial measurements
+                 "indep_results": indep_results,
+                 "indep_max_idx": indep_max_idx,
+                 "indep_agt_key_modes_and_counts": indep_agt_key_modes_and_counts,
+             },
+             result_type=ResultType.DIRECT_RESULT,
+         )
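
An end-to-end sketch of how these judges appear intended to be driven (the rubric text and agent run are placeholders, `SimpleLLMService` is the default `llm_svc` wired in above, and a real run additionally needs provider credentials for the judge model):

    import asyncio

    from docent.judges import MajorityVotingJudge, Rubric

    async def judge_run(agent_run):  # `agent_run`: a hypothetical AgentRun
        judge = MajorityVotingJudge(
            cfg=Rubric(rubric_text="The agent completed the task it was given."),
            n_rollouts_per_input=5,  # five independent rollouts, then majority vote
        )
        result = await judge(agent_run, max_concurrency=5)
        if result is None:
            print("no rollout produced a valid output")
        else:
            # With the default schema: {"label": ..., "explanation": ...}
            print(result.output)

    # asyncio.run(judge_run(some_agent_run))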
docent/judges/types.py ADDED
@@ -0,0 +1,240 @@
+ import enum
+ import json
+ from string import Formatter
+ from typing import Any, Callable, Protocol
+ from uuid import uuid4
+
+ from pydantic import BaseModel, Field, field_serializer, field_validator
+
+ from docent._llm_util.providers.preference_types import PUBLIC_PROVIDER_PREFERENCES, ModelOption
+ from docent._log_util import get_logger
+ from docent.data_models.agent_run import AgentRun
+ from docent.data_models.citation import parse_citations
+ from docent.data_models.transcript import TEXT_RANGE_CITE_INSTRUCTION
+ from docent.judges.util.meta_schema import validate_judge_result_schema
+
+ logger = get_logger(__name__)
+
+ DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
+ Here is a rubric that we are using to judge transcripts of AI agent runs.
+
+ Rubric:
+ {rubric}
+
+ Agent run:
+ {agent_run}
+
+ Your response should convey your judgment of the agent run according to the criteria given in the rubric provided above. Your entire response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
+
+ The JSON object you produce must adhere to the following schema:
+ {output_schema}
+
+ {citation_instructions}
+ """.strip()
+
+ DEFAULT_CITATION_INSTRUCTIONS = f"""
+ For strings which require citations (according to the `citations: True` property), you must also follow these instructions:
+ {TEXT_RANGE_CITE_INSTRUCTION}
+ """.strip()
+
+ DEFAULT_JUDGE_OUTPUT_SCHEMA = {
+     "type": "object",
+     "properties": {
+         "label": {"type": "string", "enum": ["match", "no match"]},
+         "explanation": {"type": "string", "citations": True},
+     },
+     # Require these properties to be present
+     "required": ["label", "explanation"],
+     # Allow additional properties though, as their presence is not breaking
+ }
+
+ DEFAULT_JUDGE_MODEL = PUBLIC_PROVIDER_PREFERENCES.default_judge_models[0]
+
+
+ class Rubric(BaseModel):
+     """TODO(mengk): this should really be called JudgeConfig,
+     but temporarily keeping this for consistency with docent_core."""
+
+     class Config:
+         frozen = True
+
+     # Primary key
+     id: str = Field(default_factory=lambda: str(uuid4()))
+     version: int = 1
+
+     # What the judge actually does
+     rubric_text: str
+
+     # Default instructions for the judge
+     system_prompt_template: str = DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE
+     citation_instructions: str = DEFAULT_CITATION_INSTRUCTIONS
+     output_schema: dict[str, Any] = DEFAULT_JUDGE_OUTPUT_SCHEMA
+
+     # How to run the judge
+     judge_model: ModelOption = DEFAULT_JUDGE_MODEL
+
+     def materialize_system_prompt(self, agent_run: AgentRun) -> str:
+         """Construct the full prompt text for rubric evaluation.
+
+         This is the canonical implementation of prompt construction - use this function
+         anywhere you need to construct a rubric evaluation prompt (including cost estimation).
+         """
+
+         output_schema_text = json.dumps(self.output_schema, indent=2)
+
+         # We've already validated that the system prompt template has these keys
+         prompt = self.system_prompt_template.format(
+             rubric=self.rubric_text,
+             agent_run=agent_run.to_text_new(),
+             output_schema=output_schema_text,
+             # Only include citation instructions if the schema requests citations
+             citation_instructions=(
+                 self.citation_instructions if _schema_requests_citations(self.output_schema) else ""
+             ),
+         ).strip()
+
+         return prompt
+
+     @field_validator("system_prompt_template")
+     @classmethod
+     def validate_system_prompt_template(cls, system_prompt_template: str):
+         # Extract all field names from the template
+         formatter = Formatter()
+         field_names = {
+             field_name
+             for _, field_name, _, _ in formatter.parse(system_prompt_template)
+             if field_name is not None
+         }
+
+         # Check for required fields
+         required_fields = {"agent_run", "output_schema", "rubric", "citation_instructions"}
+         missing_fields = required_fields - field_names
+
+         if missing_fields:
+             raise ValueError(
+                 f"system_prompt_template must contain the following placeholders: {missing_fields}"
+             )
+
+         return system_prompt_template
+
+     @field_validator("output_schema")
+     @classmethod
+     def validate_output_schema(cls, output_schema: dict[str, Any]):
+         """
+         Raises:
+             jsonschema.ValidationError: If the schema is invalid
+             jsonschema.SchemaError: If the schema is not a valid 2020-12 schema
+         """
+         validate_judge_result_schema(output_schema)
+         return output_schema
+
+
+ class ResultType(enum.Enum):
+     """Enum for the type of result that a judge result can have."""
+
+     DIRECT_RESULT = "direct_result"
+     NEAR_MISS = "near_miss"
+
+
+ class JudgeResult(BaseModel):
+     class Config:
+         frozen = True
+
+     id: str = Field(default_factory=lambda: str(uuid4()))
+     agent_run_id: str
+     rubric_id: str
+     rubric_version: int
+
+     # Outputs
+     output: dict[str, Any]
+     result_metadata: dict[str, Any] | None = None
+     result_type: ResultType
+
+     # Deprecated
+     value: str | None = None
+
+     @field_serializer("result_type")
+     def serialize_result_type(self, result_type: ResultType) -> str:
+         return result_type.value
+
+
+ class JudgeResultWithCitations(JudgeResult):
+     @classmethod
+     def from_judge_result(
+         cls, result: JudgeResult, schema: dict[str, Any]
+     ) -> "JudgeResultWithCitations":
+         """Judge result must be validated against the schema before calling this function!"""
+
+         def _parse_citation_string(output: str) -> dict[str, Any]:
+             text, citations = parse_citations(output)
+             return {"text": text, "citations": citations}
+
+         data = result.model_dump()
+         try:
+             data["output"] = traverse_schema_and_transform(
+                 data["output"], schema, _parse_citation_string
+             )
+         except Exception as e:
+             logger.error(f"Failed to parse citations: {e}")
+             logger.error(f"Output: {data['output']}")
+             data["output"] = {"raw": data["output"]}
+         return cls(**data)
+
+
+ class JudgeResultCompletionCallback(Protocol):
+     """Called when some batch of judge results is completed.
+     Supports batched calls for cases where many results are pre-computed.
+     This avoids invoking the callback separately for each datapoint.
+     """
+
+     async def __call__(
+         self,
+         batch_index: int,
+         judge_results: list[JudgeResult] | None,
+     ) -> None: ...
+
+
+ def traverse_schema_and_transform(
+     output: Any,
+     schema: dict[str, Any],
+     citation_string_handler: Callable[[str], Any],
+ ) -> Any:
+     """Recursively traverse output based on schema, applying citation_string_handler to citation strings."""
+     if schema.get("type") == "string" and schema.get("citations"):  # type: ignore
+         return citation_string_handler(output)
+     elif schema.get("type") == "object":
+         properties: dict[str, Any] = schema.get("properties", {})
+         result: dict[str, Any] = {}
+         for key in properties:
+             if key in output:
+                 result[key] = traverse_schema_and_transform(
+                     output[key], properties[key], citation_string_handler
+                 )
+         return result
+     elif schema.get("type") == "array":
+         item_schema: dict[str, Any] = schema.get("items", {})
+         return [
+             traverse_schema_and_transform(item, item_schema, citation_string_handler)
+             for item in output
+         ]
+     else:
+         return output
+
+
+ def _schema_requests_citations(schema: dict[str, Any]) -> bool:
+     """Check whether any field in the schema requests citations via a truthy `citations` property."""
+
+     def _check_field(field_schema: Any) -> bool:
+         if isinstance(field_schema, dict):
+             if field_schema.get("citations"):  # type: ignore
+                 return True
+             for value in field_schema.values():  # type: ignore
+                 if isinstance(value, dict) and _check_field(value):
+                     return True
+                 elif isinstance(value, list):
+                     for item in value:  # type: ignore
+                         if isinstance(item, dict) and _check_field(item):
+                             return True
+         return False
+
+     return _check_field(schema)
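
To make `traverse_schema_and_transform` concrete, a self-contained sketch with a stand-in handler (`JudgeResultWithCitations.from_judge_result` uses `parse_citations` here instead; the schema and output below are invented):

    from docent.judges.types import traverse_schema_and_transform

    schema = {
        "type": "object",
        "properties": {
            "label": {"type": "string"},
            "explanation": {"type": "string", "citations": True},
        },
    }
    output = {"label": "match", "explanation": "The agent edited config.py."}

    # Only fields whose schema is marked `citations: True` pass through the handler.
    wrapped = traverse_schema_and_transform(
        output, schema, lambda s: {"text": s, "citations": []}
    )
    assert wrapped == {
        "label": "match",
        "explanation": {"text": "The agent edited config.py.", "citations": []},
    }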