docent-python 0.1.41a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic. Click here for more details.

Files changed (59)
  1. docent/__init__.py +4 -0
  2. docent/_llm_util/__init__.py +0 -0
  3. docent/_llm_util/data_models/__init__.py +0 -0
  4. docent/_llm_util/data_models/exceptions.py +48 -0
  5. docent/_llm_util/data_models/llm_output.py +331 -0
  6. docent/_llm_util/llm_cache.py +193 -0
  7. docent/_llm_util/llm_svc.py +472 -0
  8. docent/_llm_util/model_registry.py +134 -0
  9. docent/_llm_util/providers/__init__.py +0 -0
  10. docent/_llm_util/providers/anthropic.py +537 -0
  11. docent/_llm_util/providers/common.py +41 -0
  12. docent/_llm_util/providers/google.py +530 -0
  13. docent/_llm_util/providers/openai.py +745 -0
  14. docent/_llm_util/providers/openrouter.py +375 -0
  15. docent/_llm_util/providers/preference_types.py +104 -0
  16. docent/_llm_util/providers/provider_registry.py +164 -0
  17. docent/_log_util/__init__.py +3 -0
  18. docent/_log_util/logger.py +141 -0
  19. docent/data_models/__init__.py +14 -0
  20. docent/data_models/_tiktoken_util.py +91 -0
  21. docent/data_models/agent_run.py +473 -0
  22. docent/data_models/chat/__init__.py +37 -0
  23. docent/data_models/chat/content.py +56 -0
  24. docent/data_models/chat/message.py +191 -0
  25. docent/data_models/chat/tool.py +109 -0
  26. docent/data_models/citation.py +187 -0
  27. docent/data_models/formatted_objects.py +84 -0
  28. docent/data_models/judge.py +17 -0
  29. docent/data_models/metadata_util.py +16 -0
  30. docent/data_models/regex.py +56 -0
  31. docent/data_models/transcript.py +305 -0
  32. docent/data_models/util.py +170 -0
  33. docent/judges/__init__.py +23 -0
  34. docent/judges/analysis.py +77 -0
  35. docent/judges/impl.py +587 -0
  36. docent/judges/runner.py +129 -0
  37. docent/judges/stats.py +205 -0
  38. docent/judges/types.py +320 -0
  39. docent/judges/util/forgiving_json.py +108 -0
  40. docent/judges/util/meta_schema.json +86 -0
  41. docent/judges/util/meta_schema.py +29 -0
  42. docent/judges/util/parse_output.py +68 -0
  43. docent/judges/util/voting.py +139 -0
  44. docent/loaders/load_inspect.py +215 -0
  45. docent/py.typed +0 -0
  46. docent/samples/__init__.py +3 -0
  47. docent/samples/load.py +9 -0
  48. docent/samples/log.eval +0 -0
  49. docent/samples/tb_airline.json +1 -0
  50. docent/sdk/__init__.py +0 -0
  51. docent/sdk/agent_run_writer.py +317 -0
  52. docent/sdk/client.py +1186 -0
  53. docent/sdk/llm_context.py +432 -0
  54. docent/trace.py +2741 -0
  55. docent/trace_temp.py +1086 -0
  56. docent_python-0.1.41a0.dist-info/METADATA +33 -0
  57. docent_python-0.1.41a0.dist-info/RECORD +59 -0
  58. docent_python-0.1.41a0.dist-info/WHEEL +4 -0
  59. docent_python-0.1.41a0.dist-info/licenses/LICENSE.md +13 -0
@@ -0,0 +1,305 @@
1
+ import sys
2
+ import textwrap
3
+ from datetime import datetime
4
+ from typing import Any, Iterable
5
+ from uuid import uuid4
6
+
7
+ import yaml
8
+ from pydantic import BaseModel, Field, field_validator
9
+ from pydantic_core import to_jsonable_python
10
+
11
+ from docent.data_models._tiktoken_util import (
12
+ get_token_count,
13
+ group_messages_into_ranges,
14
+ truncate_to_token_limit,
15
+ )
16
+ from docent.data_models.chat import AssistantMessage, ChatMessage, ContentReasoning
17
+ from docent.data_models.citation import RANGE_BEGIN, RANGE_END
18
+ from docent.data_models.metadata_util import dump_metadata
19
+
20
# Template for formatting individual transcript blocks.
# index_label is a citation label such as "T0B1"; role is the message role;
# content is the fully rendered message body (reasoning, text, tool calls, metadata).
TRANSCRIPT_BLOCK_TEMPLATE = """
<|{index_label}; role: {role}|>
{content}
</|{index_label}; role: {role}|>
""".strip()

# Instructions for citing single transcript blocks, including exact text ranges.
# Interpolates RANGE_BEGIN/RANGE_END so the examples match the real citation markers.
TEXT_RANGE_CITE_INSTRUCTION = f"""Anytime you quote the transcript, or refer to something that happened in the transcript, or make any claim about the transcript, add an inline citation. Each transcript and each block has a unique index. Cite the relevant indices in brackets. For example, to cite the entirety of transcript 0, block 1, write [T0B1].

A citation may include a specific range of text within a block. Use {RANGE_BEGIN} and {RANGE_END} to mark the specific range of text. Add it after the block ID separated by a colon. For example, to cite the part of transcript 0, block 1, where the agent says "I understand the task", write [T0B1:{RANGE_BEGIN}I understand the task{RANGE_END}]. Citations must follow this exact format. The markers {RANGE_BEGIN} and {RANGE_END} must be used ONLY inside the brackets of a citation.

- You may cite a top-level key in the agent run metadata like this: [M.task_description].
- You may cite a top-level key in transcript metadata. For example, for transcript 0: [T0M.start_time].
- You may cite a top-level key in message metadata for a block. For example, for transcript 0, block 1: [T0B1M.status].
- You may not cite nested keys. For example, [T0B1M.status.code] is invalid.
- Within a top-level metadata key you may cite a range of text that appears in the value. For example, [T0B1M.status:{RANGE_BEGIN}"running":false{RANGE_END}].

Important notes:
- You must include the full content of the text range {RANGE_BEGIN} and {RANGE_END}, EXACTLY as it appears in the transcript, word-for-word, including any markers or punctuation that appear in the middle of the text.
- Citations must be as specific as possible. This means you should usually cite a specific text range within a block.
- A citation is not a quote. For brevity, text ranges will not be rendered inline. The user will have to click on the citation to see the full text range.
- Citations are self-contained. Do NOT label them as citation or evidence. Just insert the citation by itself at the appropriate place in the text.
- Citations must come immediately after the part of a claim that they support. This may be in the middle of a sentence.
- Each pair of brackets must contain only one citation. To cite multiple blocks, use multiple pairs of brackets, like [T0B0] [T0B1].
- Outside of citations, do not refer to transcript numbers or block numbers.
- Outside of citations, avoid quoting or paraphrasing the transcript.
"""

# Shorter instructions for citing whole blocks only (no text ranges).
BLOCK_CITE_INSTRUCTION = """Each transcript and each block has a unique index. Cite the relevant indices in brackets when relevant, like [T<idx>B<idx>]. Use multiple tags to cite multiple blocks, like [T<idx1>B<idx1>][T<idx2>B<idx2>]. Remember to cite specific blocks and NOT action units."""
50
+
51
+
52
def format_chat_message(
    message: ChatMessage,
    index_label: str,
) -> str:
    """Render a single chat message as a labeled transcript block.

    The rendered body consists of an optional <reasoning> prefix (assistant
    messages only; if several reasoning segments exist, the last one is used),
    the message text, any tool calls, and an optional message-metadata footer.

    Args:
        message: The chat message to render.
        index_label: Citation label for the block, e.g. "T0B1".

    Returns:
        str: The message wrapped in TRANSCRIPT_BLOCK_TEMPLATE.
    """
    parts: list[str] = []

    # Assistant messages may carry reasoning content; prepend the last segment.
    if isinstance(message, AssistantMessage) and message.content:
        reasoning_segments = [c for c in message.content if isinstance(c, ContentReasoning)]
        if reasoning_segments:
            parts.append(f"<reasoning>\n{reasoning_segments[-1].reasoning}\n</reasoning>\n")

    # Main content text
    parts.append(message.text)

    # Render tool calls, preferring a provider-supplied view when available.
    if isinstance(message, AssistantMessage) and message.tool_calls:
        for tc in message.tool_calls:
            if tc.view:
                parts.append(f"\n<tool call>\n{tc.view.content}\n</tool call>")
            else:
                rendered_args = ", ".join(f"{k}={v}" for k, v in tc.arguments.items())
                parts.append(f"\n<tool call>\n{tc.function}({rendered_args})\n</tool call>")

    # Append message metadata when it serializes to something non-empty.
    if message.metadata:
        metadata_text = dump_metadata(message.metadata)
        if metadata_text is not None:
            parts.append(f"\n<|message metadata|>\n{metadata_text}\n</|message metadata|>")

    return TRANSCRIPT_BLOCK_TEMPLATE.format(
        index_label=index_label, role=message.role, content="".join(parts)
    )
84
+
85
+
86
class TranscriptGroup(BaseModel):
    """A logical grouping of related transcripts.

    Groups may be nested via `parent_transcript_group_id`, allowing transcripts
    to be organized hierarchically (e.g. experiments, tasks, sessions).

    Attributes:
        id: Unique identifier for the transcript group, auto-generated by default.
        name: Optional human-readable name for the transcript group.
        description: Optional description of the transcript group.
        agent_run_id: ID of the agent run this transcript group belongs to.
        parent_transcript_group_id: Optional ID of the parent transcript group.
        created_at: Optional creation timestamp.
        metadata: Additional structured metadata about the transcript group.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str | None = None
    description: str | None = None
    agent_run_id: str
    parent_transcript_group_id: str | None = None
    created_at: datetime | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("metadata", mode="before")
    @classmethod
    def _validate_metadata_type(cls, v: Any) -> Any:
        # Fail fast with a clear message when metadata is not a mapping.
        if v is not None and not isinstance(v, dict):
            raise ValueError(f"metadata must be a dictionary, got {type(v).__name__}")
        return v  # type: ignore

    def to_text_new(self, children_text: str, indent: int = 0) -> str:
        """Render this transcript group with its children and metadata.

        Metadata (if any) is placed below the pre-rendered children content,
        and the whole body is wrapped in XML-like tags named after the group.

        Args:
            children_text: Pre-rendered text of this group's children (groups/transcripts).
            indent: Number of spaces to indent the rendered output.

        Returns:
            str: XML-like wrapped text including the group's metadata.
        """
        pad = " " * indent
        metadata_text = dump_metadata(self.metadata)
        if metadata_text is None:
            body = children_text
        else:
            if indent > 0:
                metadata_text = textwrap.indent(metadata_text, pad)
            body = (
                f"{children_text}\n<|{self.name} metadata|>\n"
                f"{metadata_text}\n</|{self.name} metadata|>"
            )

        # Indent the composed body, then wrap it in the group tags.
        if indent > 0:
            body = textwrap.indent(body, pad)
        return f"<|{self.name}|>\n{body}\n</|{self.name}|>"
143
+
144
+
145
class Transcript(BaseModel):
    """Represents a transcript of messages in a conversation with an AI agent.

    A transcript contains a sequence of messages exchanged between different roles
    (system, user, assistant, tool) and provides methods to render those messages
    as text, optionally split into chunks that respect a token limit.

    Attributes:
        id: Unique identifier for the transcript, auto-generated by default.
        name: Optional human-readable name for the transcript.
        description: Optional description of the transcript.
        transcript_group_id: Optional ID of the transcript group this transcript belongs to.
        created_at: Optional creation timestamp.
        messages: List of chat messages in the transcript.
        metadata: Additional structured metadata about the transcript.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str | None = None
    description: str | None = None
    transcript_group_id: str | None = None
    created_at: datetime | None = None

    messages: list[ChatMessage]
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("metadata", mode="before")
    @classmethod
    def _validate_metadata_type(cls, v: Any) -> Any:
        # Fail fast with a clear message when metadata is not a mapping.
        if v is not None and not isinstance(v, dict):
            raise ValueError(f"metadata must be a dict, got {type(v).__name__}")
        return v  # type: ignore

    # NOTE: the previous no-op __init__ (which only forwarded to super().__init__)
    # was removed; pydantic's generated constructor is identical in behavior.

    def set_messages(self, messages: list[ChatMessage]):
        """Replace the list of messages in the transcript.

        Args:
            messages: The new list of chat messages to set.
        """
        self.messages = messages

    def _generate_formatted_blocks(
        self,
        transcript_idx: int = 0,
    ) -> list[str]:
        """Generate one formatted block per message.

        Args:
            transcript_idx: Index of the transcript, used to build block labels
                of the form T{transcript_idx}B{msg_idx}.

        Returns:
            list[str]: One formatted block per message, in message order.
        """
        return [
            format_chat_message(message, index_label=f"T{transcript_idx}B{msg_idx}")
            for msg_idx, message in enumerate(self.messages)
        ]

    def to_str(
        self,
        token_limit: int = sys.maxsize,
        transcript_idx: int = 0,
    ) -> list[str]:
        """Render the transcript as one or more strings, each within a token limit.

        Args:
            token_limit: Maximum tokens per returned string.
            transcript_idx: Index of the transcript, used in block labels.

        Returns:
            list[str]: List of strings, each within the token limit.
        """
        blocks: list[str] = self._generate_formatted_blocks(transcript_idx)
        blocks_str = "\n".join(blocks)

        # Serialize metadata as YAML with no line wrapping.
        metadata_obj = to_jsonable_python(self.metadata)
        yaml_width = float("inf")
        block_str = f"<blocks>\n{blocks_str}\n</blocks>\n"
        metadata_str = f"<|transcript metadata|>\n{yaml.dump(metadata_obj, width=yaml_width)}\n</|transcript metadata|>"

        # Fast path: no effective limit, so skip token counting entirely.
        if token_limit == sys.maxsize:
            return [block_str + metadata_str]

        metadata_token_count = get_token_count(metadata_str)
        block_token_count = get_token_count(block_str)

        if metadata_token_count + block_token_count <= token_limit:
            return [block_str + metadata_str]

        # Too large: split the blocks into ranges that each fit within the limit.
        results: list[str] = []
        block_token_counts = [get_token_count(block) for block in blocks]
        ranges = group_messages_into_ranges(
            block_token_counts, metadata_token_count, token_limit
        )
        for msg_range in ranges:
            if msg_range.include_metadata:
                cur_blocks = "\n".join(blocks[msg_range.start : msg_range.end])
                results.append(f"<blocks>\n{cur_blocks}\n</blocks>\n" + metadata_str)
            else:
                assert (
                    msg_range.end == msg_range.start + 1
                ), "Ranges without metadata should be a single message"
                result = str(blocks[msg_range.start])
                # Reserve a small margin (10 tokens) for the <blocks> wrapper.
                if msg_range.num_tokens > token_limit - 10:
                    result = truncate_to_token_limit(result, token_limit - 10)
                results.append(f"<blocks>\n{result}\n</blocks>\n")

        return results

    ##############################
    # New text rendering methods #
    ##############################

    def _enumerate_messages(self) -> Iterable[tuple[int, ChatMessage]]:
        """Yield (index, message) tuples for rendering.

        Override in subclasses to customize index assignment.
        """
        return enumerate(self.messages)

    def to_text_new(self, transcript_alias: int | str = 0, indent: int = 0) -> str:
        """Render the transcript as XML-like wrapped text with citation labels.

        Args:
            transcript_alias: Alias used in labels; an int N is rendered as "TN".
            indent: Number of spaces to indent the rendered output.

        Returns:
            str: The wrapped transcript text, including metadata if present.
        """
        if isinstance(transcript_alias, int):
            transcript_alias = f"T{transcript_alias}"

        # Format individual message blocks, labeled e.g. "T0B3".
        blocks: list[str] = []
        for msg_idx, message in self._enumerate_messages():
            block_label = f"{transcript_alias}B{msg_idx}"
            blocks.append(format_chat_message(message, block_label))
        blocks_str = "\n".join(blocks)
        if indent > 0:
            blocks_str = textwrap.indent(blocks_str, " " * indent)

        content_str = f"<|{transcript_alias} blocks|>\n{blocks_str}\n</|{transcript_alias} blocks|>"

        # Append transcript metadata, labeled for citation (e.g. [T0M.key]).
        metadata_text = dump_metadata(self.metadata)
        if metadata_text is not None:
            if indent > 0:
                metadata_text = textwrap.indent(metadata_text, " " * indent)
            metadata_label = f"{transcript_alias}M"
            content_str += f"\n<|transcript metadata {metadata_label}|>\n{metadata_text}\n</|transcript metadata {metadata_label}|>"

        # Wrap the whole transcript and apply the outer indent.
        if indent > 0:
            content_str = textwrap.indent(content_str, " " * indent)
        return f"<|transcript {transcript_alias}|>\n{content_str}\n</|transcript {transcript_alias}|>\n"
@@ -0,0 +1,170 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, Iterable, List, TypeVar
4
+ from uuid import uuid4
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from docent.data_models.agent_run import AgentRun
9
+
10
+ T = TypeVar("T", bound=BaseModel)
11
+
12
+
13
+ def _deep_copy_model(model: T) -> T:
14
+ """Create a deep copy of a Pydantic v2 model.
15
+
16
+ Using `model_copy(deep=True)` ensures nested models are fully copied and
17
+ mutations do not affect the original instance.
18
+ """
19
+ return model.model_copy(deep=True)
20
+
21
+
22
def _validate_source_run(agent_run: AgentRun) -> None:
    """Check referential integrity of the source run before cloning.

    Raises:
        ValueError: On duplicate transcript or group IDs, dangling group
            references, or a group whose agent_run_id does not match the run.
    """
    src_transcript_ids = [str(t.id) for t in agent_run.transcripts]
    if len(src_transcript_ids) != len(set(src_transcript_ids)):
        raise ValueError("Duplicate transcript ids detected in source AgentRun")

    src_group_ids = [str(g.id) for g in agent_run.transcript_groups]
    if len(src_group_ids) != len(set(src_group_ids)):
        raise ValueError("Duplicate transcript group ids detected in source AgentRun")

    src_group_id_set = set(src_group_ids)
    for t in agent_run.transcripts:
        if t.transcript_group_id is not None and str(t.transcript_group_id) not in src_group_id_set:
            raise ValueError(
                f"Transcript {t.id} references missing transcript_group_id {t.transcript_group_id}"
            )

    for g in agent_run.transcript_groups:
        if (
            g.parent_transcript_group_id is not None
            and str(g.parent_transcript_group_id) not in src_group_id_set
        ):
            raise ValueError(
                f"TranscriptGroup {g.id} references missing parent_transcript_group_id {g.parent_transcript_group_id}"
            )
        if str(g.agent_run_id) != str(agent_run.id):
            raise ValueError(
                f"TranscriptGroup {g.id} has agent_run_id {g.agent_run_id} which does not match AgentRun.id {agent_run.id}"
            )


def _validate_cloned_run(new_run: AgentRun) -> None:
    """Re-check referential integrity on the cloned run (post-condition).

    Raises:
        ValueError: If the remapping produced duplicates, dangling references,
            or a group not pointing at the cloned run's ID.
    """
    new_group_ids = [str(g.id) for g in new_run.transcript_groups]
    if len(new_group_ids) != len(set(new_group_ids)):
        raise ValueError("Duplicate transcript group ids detected after cloning")
    new_group_id_set = set(new_group_ids)

    new_transcript_ids = [str(t.id) for t in new_run.transcripts]
    if len(new_transcript_ids) != len(set(new_transcript_ids)):
        raise ValueError("Duplicate transcript ids detected after cloning")

    for t in new_run.transcripts:
        if t.transcript_group_id is not None and str(t.transcript_group_id) not in new_group_id_set:
            raise ValueError(
                f"Transcript {t.id} references missing transcript_group_id {t.transcript_group_id} after cloning"
            )

    for g in new_run.transcript_groups:
        if (
            g.parent_transcript_group_id is not None
            and str(g.parent_transcript_group_id) not in new_group_id_set
        ):
            raise ValueError(
                f"TranscriptGroup {g.id} references missing parent_transcript_group_id {g.parent_transcript_group_id} after cloning"
            )
        if str(g.agent_run_id) != str(new_run.id):
            raise ValueError(
                f"TranscriptGroup {g.id} has agent_run_id {g.agent_run_id} which does not match cloned AgentRun.id {new_run.id}"
            )


def clone_agent_run_with_random_ids(agent_run: AgentRun) -> AgentRun:
    """Clone an `AgentRun`, randomizing all IDs and fixing internal references.

    The following transformations are performed on the cloned instance:
    - Assign a new `AgentRun.id`.
    - Assign new `Transcript.id` values.
    - Assign new `TranscriptGroup.id` values.
    - Update `Transcript.transcript_group_id` to the new group IDs where applicable.
    - Update `TranscriptGroup.agent_run_id` to the new `AgentRun.id`.
    - Update `TranscriptGroup.parent_transcript_group_id` to the new group IDs where applicable.

    Notes:
    - Integrity is validated both before and after cloning; any dangling
      `transcript_group_id` or `parent_transcript_group_id` reference raises.

    Args:
        agent_run: The source `AgentRun` to clone.

    Returns:
        A new, independent `AgentRun` instance with randomized identifiers and consistent references.

    Raises:
        ValueError: If the source or cloned run fails integrity validation.
    """
    # Validate source integrity before doing any work.
    _validate_source_run(agent_run)

    # Deep copy first so we never mutate the caller's instance.
    new_run = _deep_copy_model(agent_run)

    new_agent_run_id = str(uuid4())

    # Pre-compute fresh IDs so references can be remapped consistently below.
    old_to_new_transcript_id: Dict[str, str] = {
        str(t.id): str(uuid4()) for t in new_run.transcripts
    }
    old_to_new_group_id: Dict[str, str] = {
        str(g.id): str(uuid4()) for g in new_run.transcript_groups
    }

    # Mutate transcript groups: new id, new agent_run_id, remapped parent.
    for group in new_run.transcript_groups:
        old_group_id = str(group.id)

        # Key is guaranteed present: the map was built from these same groups.
        group.id = old_to_new_group_id[old_group_id]

        # Ensure group points to the new agent run id.
        group.agent_run_id = new_agent_run_id

        # Remap parent id; raise if unknown.
        if group.parent_transcript_group_id is not None:
            old_parent_id = str(group.parent_transcript_group_id)
            if old_parent_id not in old_to_new_group_id:
                raise ValueError(
                    f"TranscriptGroup {old_group_id} parent_transcript_group_id {old_parent_id} not found in this AgentRun"
                )
            group.parent_transcript_group_id = old_to_new_group_id[old_parent_id]

    # Mutate transcripts: new id, remapped group reference.
    for transcript in new_run.transcripts:
        old_transcript_id = str(transcript.id)

        # Key is guaranteed present: the map was built from these same transcripts.
        transcript.id = old_to_new_transcript_id[old_transcript_id]

        # Remap group reference; raise if unknown.
        if transcript.transcript_group_id is not None:
            old_group_id_ref = str(transcript.transcript_group_id)
            if old_group_id_ref not in old_to_new_group_id:
                raise ValueError(
                    f"Transcript {old_transcript_id} references transcript_group_id {old_group_id_ref} not found in this AgentRun"
                )
            transcript.transcript_group_id = old_to_new_group_id[old_group_id_ref]

    # Finally set the new run id.
    new_run.id = new_agent_run_id

    # Post-validate integrity on the cloned run.
    _validate_cloned_run(new_run)

    return new_run
159
+
160
+
161
def clone_agent_runs_with_random_ids(agent_runs: Iterable[AgentRun]) -> List[AgentRun]:
    """Clone each `AgentRun` in *agent_runs* with randomized IDs.

    Args:
        agent_runs: Iterable of `AgentRun` instances to clone.

    Returns:
        A list of cloned `AgentRun` instances with fresh IDs and consistent references.
    """
    return list(map(clone_agent_run_with_random_ids, agent_runs))
@@ -0,0 +1,23 @@
1
+ from docent.judges.impl import BaseJudge, MajorityVotingJudge, MultiReflectionJudge
2
+ from docent.judges.types import (
3
+ JudgeResult,
4
+ JudgeResultCompletionCallback,
5
+ JudgeResultWithCitations,
6
+ JudgeVariant,
7
+ ResultType,
8
+ Rubric,
9
+ )
10
+
11
+ __all__ = [
12
+ # Judges
13
+ "MajorityVotingJudge",
14
+ "MultiReflectionJudge",
15
+ "BaseJudge",
16
+ # Types
17
+ "Rubric",
18
+ "JudgeResult",
19
+ "JudgeResultWithCitations",
20
+ "JudgeResultCompletionCallback",
21
+ "ResultType",
22
+ "JudgeVariant",
23
+ ]
@@ -0,0 +1,77 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ import anyio
6
+ from pydantic import BaseModel
7
+ from pydantic_core import to_jsonable_python
8
+ from tqdm.auto import tqdm
9
+
10
+ from docent._log_util import get_logger
11
+ from docent.data_models.agent_run import AgentRun
12
+ from docent.judges.impl import BaseJudge
13
+ from docent.judges.util.voting import JudgeOutputDistribution
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
class MultiReflectRollouts(BaseModel):
    """Rollout data collected for a multi-reflection judge.

    Object is associated with a single agent run.
    """

    # ID of the agent run these rollouts belong to.
    agent_run_id: str

    # First-step judge rollouts, one dict per rollout.
    # NOTE(review): dict schema is defined by the judge implementation — confirm upstream.
    first_step_rollouts: list[dict[str, Any]]
    # Per-rollout metadata for the first step; None when a rollout produced no metadata.
    first_step_rollout_metadata: list[dict[str, Any] | None]
    # Each index in second_step_rollouts corresponds to an index in first_step_combinations
    # Step 2 rollouts are computed by passing each step 1 combo into the judge several times
    first_step_combinations: list[list[dict[str, Any]]] | None = None
    second_step_rollouts: list[list[dict[str, Any]]] | None = None
    # Metadata parallel to second_step_rollouts (outer index = combination, inner = rollout).
    second_step_rollout_metadata: list[list[dict[str, Any] | None]] | None = None

    # Estimated output distributions, keyed by output name.
    distributions: dict[str, JudgeOutputDistribution]
32
+
33
+
34
async def collect_judge_pvs(
    judge: BaseJudge,
    agent_runs: list[AgentRun],
    *,
    results_path: Path,
    estimate_output_distrs_kwargs: dict[str, Any],
):
    """Estimate judge output distributions for each agent run and persist them.

    Runs `judge.estimate_output_distrs` concurrently across *agent_runs*,
    accumulating `MultiReflectRollouts` keyed by agent run ID and re-writing
    the JSON file at *results_path* after each run completes, so partial
    progress survives a crash.

    Args:
        judge: Judge used to estimate output distributions.
        agent_runs: Agent runs to process.
        results_path: File to incrementally write JSON results to. Must not
            already exist; parent directories are created as needed.
        estimate_output_distrs_kwargs: Extra keyword arguments forwarded to
            `judge.estimate_output_distrs`.

    Returns:
        dict[str, MultiReflectRollouts]: Results keyed by agent run ID. Runs
        for which the judge returned None are omitted.

    Raises:
        FileExistsError: If *results_path* already exists.
    """
    if results_path.exists():
        raise FileExistsError(f"Results path already exists: {results_path}")
    results_path.parent.mkdir(parents=True, exist_ok=True)

    results = dict[str, MultiReflectRollouts]()
    persist_lock = anyio.Lock()
    pbar = tqdm(total=len(agent_runs), desc="Processing agent runs")

    async def _persist():
        # Serialize writes so concurrent tasks never interleave file output.
        async with persist_lock:
            # Explicit encoding keeps output locale-independent.
            with open(results_path, "w", encoding="utf-8") as f:
                json.dump(to_jsonable_python(results), f, indent=2)

    async def _execute_for_agent_run(agent_run: AgentRun):
        result = await judge.estimate_output_distrs(agent_run, **estimate_output_distrs_kwargs)
        if result is None:
            # Judge produced no distributions for this run; count it and move on.
            pbar.update(1)
            return

        distrs, metadata = result
        results[agent_run.id] = MultiReflectRollouts.model_validate(
            {
                "agent_run_id": agent_run.id,
                "distributions": distrs,
                **metadata,
            }
        )
        await _persist()
        pbar.update(1)

    try:
        async with anyio.create_task_group() as tg_outer:
            for agent_run in agent_runs:
                tg_outer.start_soon(_execute_for_agent_run, agent_run)
    finally:
        # Close the progress bar even if a task raises.
        pbar.close()

    return results