docent-python 0.1.41a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python might be problematic. Click here for more details.
- docent/__init__.py +4 -0
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +331 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/llm_svc.py +472 -0
- docent/_llm_util/model_registry.py +134 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/_log_util/__init__.py +3 -0
- docent/_log_util/logger.py +141 -0
- docent/data_models/__init__.py +14 -0
- docent/data_models/_tiktoken_util.py +91 -0
- docent/data_models/agent_run.py +473 -0
- docent/data_models/chat/__init__.py +37 -0
- docent/data_models/chat/content.py +56 -0
- docent/data_models/chat/message.py +191 -0
- docent/data_models/chat/tool.py +109 -0
- docent/data_models/citation.py +187 -0
- docent/data_models/formatted_objects.py +84 -0
- docent/data_models/judge.py +17 -0
- docent/data_models/metadata_util.py +16 -0
- docent/data_models/regex.py +56 -0
- docent/data_models/transcript.py +305 -0
- docent/data_models/util.py +170 -0
- docent/judges/__init__.py +23 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +587 -0
- docent/judges/runner.py +129 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +320 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +86 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +68 -0
- docent/judges/util/voting.py +139 -0
- docent/loaders/load_inspect.py +215 -0
- docent/py.typed +0 -0
- docent/samples/__init__.py +3 -0
- docent/samples/load.py +9 -0
- docent/samples/log.eval +0 -0
- docent/samples/tb_airline.json +1 -0
- docent/sdk/__init__.py +0 -0
- docent/sdk/agent_run_writer.py +317 -0
- docent/sdk/client.py +1186 -0
- docent/sdk/llm_context.py +432 -0
- docent/trace.py +2741 -0
- docent/trace_temp.py +1086 -0
- docent_python-0.1.41a0.dist-info/METADATA +33 -0
- docent_python-0.1.41a0.dist-info/RECORD +59 -0
- docent_python-0.1.41a0.dist-info/WHEEL +4 -0
- docent_python-0.1.41a0.dist-info/licenses/LICENSE.md +13 -0
|
@@ -0,0 +1,473 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import textwrap
|
|
3
|
+
from collections import deque
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any, Literal, TypedDict, cast
|
|
6
|
+
from uuid import uuid4
|
|
7
|
+
|
|
8
|
+
import yaml
|
|
9
|
+
from pydantic import (
|
|
10
|
+
BaseModel,
|
|
11
|
+
Field,
|
|
12
|
+
PrivateAttr,
|
|
13
|
+
field_validator,
|
|
14
|
+
model_validator,
|
|
15
|
+
)
|
|
16
|
+
from pydantic_core import to_jsonable_python
|
|
17
|
+
|
|
18
|
+
from docent._log_util import get_logger
|
|
19
|
+
from docent.data_models._tiktoken_util import get_token_count, group_messages_into_ranges
|
|
20
|
+
from docent.data_models.metadata_util import dump_metadata
|
|
21
|
+
from docent.data_models.transcript import Transcript, TranscriptGroup
|
|
22
|
+
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class FilterableField(TypedDict):
    """Describes one field that an agent run can be filtered on."""

    # Path of the field. AgentRun.get_filterable_fields emits dotted
    # "metadata.<key>" paths plus the fixed names "agent_run_id" and "text".
    name: str
    # Name of the Python scalar type of the field's value.
    type: Literal["str", "bool", "int", "float"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AgentRun(BaseModel):
    """Represents a complete run of an agent with transcripts and metadata.

    An AgentRun encapsulates the execution of an agent, storing all communication
    transcripts and associated metadata. It must contain at least one transcript.

    Attributes:
        id: Unique identifier for the agent run; defaults to a freshly generated
            UUID4 string.
        name: Optional human-readable name for the agent run.
        description: Optional description of the agent run.
        transcripts: List of Transcript objects (required).
        transcript_groups: List of TranscriptGroup objects; defaults to an empty list.
        metadata: Additional structured metadata about the agent run as a
            JSON-serializable dictionary; defaults to an empty dict.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str | None = None
    description: str | None = None

    transcripts: list[Transcript]
    transcript_groups: list[TranscriptGroup] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
53
|
+
|
|
54
|
+
@field_validator("transcripts", mode="before")
@classmethod
def _validate_transcripts_type(cls, v: Any) -> Any:
    """Accept the deprecated dict form of ``transcripts``.

    Args:
        v: Raw value supplied for ``transcripts`` before field validation.

    Returns:
        ``v`` unchanged unless it is a dict, in which case a warning is logged
        and the dict's values are validated into a list of Transcript objects.
    """
    # Anything that is not the legacy mapping goes straight to normal validation.
    if not isinstance(v, dict):
        return v

    logger.warning(
        "dict[str, Transcript] for transcripts is deprecated. Use list[Transcript] instead."
    )
    legacy_mapping = cast(dict[str, Transcript], v)
    # The keys are discarded; only the values are kept, in dict order.
    return [Transcript.model_validate(item) for item in legacy_mapping.values()]
|
|
64
|
+
|
|
65
|
+
@field_validator("transcript_groups", mode="before")
@classmethod
def _validate_transcript_groups_type(cls, v: Any) -> Any:
    """Accept the deprecated dict form of ``transcript_groups``.

    Args:
        v: Raw value supplied for ``transcript_groups`` before field validation.

    Returns:
        ``v`` unchanged unless it is a dict, in which case a warning is logged
        and the dict's values are validated into a list of TranscriptGroup objects.
    """
    # Anything that is not the legacy mapping goes straight to normal validation.
    if not isinstance(v, dict):
        return v

    logger.warning(
        "dict[str, TranscriptGroup] for transcript_groups is deprecated. Use list[TranscriptGroup] instead."
    )
    legacy_mapping = cast(dict[str, TranscriptGroup], v)
    # The keys are discarded; only the values are kept, in dict order.
    return [TranscriptGroup.model_validate(item) for item in legacy_mapping.values()]
|
|
75
|
+
|
|
76
|
+
@model_validator(mode="after")
def _validate_transcripts_not_empty(self):
    """Ensure the agent run holds at least one transcript.

    Raises:
        ValueError: If the transcripts list is empty.

    Returns:
        AgentRun: This instance, once it has passed the check.
    """
    if self.transcripts:
        return self
    raise ValueError("AgentRun must have at least one transcript")
|
|
89
|
+
|
|
90
|
+
def get_filterable_fields(self, max_depth: int = 1) -> list[FilterableField]:
    """List the fields that can be used to filter this agent run.

    The JSON-able form of ``self.metadata`` is walked recursively; every
    str/int/float/bool value found in a dictionary becomes one field whose
    name is its dotted path starting with ``metadata``.

    Args:
        max_depth: Deepest nesting level to descend into. The top-level
            metadata dict is level 0.

    Returns:
        list[FilterableField]: The metadata fields in traversal order, followed
        by the fixed ``agent_run_id`` and ``text`` fields. Each field is a
        dictionary containing its 'name' (path) and 'type'.
    """

    def _walk(node: dict[str, Any], path: str, level: int):
        # Dictionaries nested deeper than the limit contribute nothing.
        if level > max_depth:
            return
        for key, value in node.items():
            child_path = f"{path}.{key}"
            if isinstance(value, (str, int, float, bool)):
                field: FilterableField = {
                    "name": child_path,
                    "type": cast(Literal["str", "bool", "int", "float"], type(value).__name__),
                }
                yield field
            elif isinstance(value, dict):
                yield from _walk(cast(dict[str, Any], value), child_path, level + 1)

    # Look at the agent run metadata
    fields: list[FilterableField] = list(_walk(to_jsonable_python(self.metadata), "metadata", 0))

    # Look at the transcript metadata
    # TODO(mengk): restore this later when we have the ability to integrate with SQL.
    # for t_id, t in self.transcripts.items():
    #     _explore_dict(
    #         t.metadata.model_dump(strip_internal_fields=True), f"transcript.{t_id}.metadata", 0
    #     )

    # Fields that are always available, regardless of metadata contents
    fields.append({"name": "agent_run_id", "type": "str"})
    fields.append({"name": "text", "type": "str"})

    return fields
|
|
132
|
+
|
|
133
|
+
######################
|
|
134
|
+
# Converting to text #
|
|
135
|
+
######################
|
|
136
|
+
|
|
137
|
+
def _to_text_impl(self, token_limit: int = sys.maxsize) -> list[str]:
    """
    Core implementation for converting agent run to text representation.

    Args:
        token_limit: Maximum tokens per returned string under the GPT-4 tokenization scheme.
            The default (``sys.maxsize``) disables chunking and skips token counting.

    Returns:
        List of strings, each at most token_limit tokens. A single-element list is
        returned when the whole run (transcripts plus metadata) fits.
    """
    # Render every transcript once, with no per-transcript limit. These strings are
    # reused for the single-string result and for whole-transcript chunks.
    transcript_strs: list[str] = []
    for i, t in enumerate(self.transcripts):
        transcript_content = t.to_str(
            token_limit=sys.maxsize,
            transcript_idx=i,
        )[0]
        transcript_strs.append(f"<transcript>\n{transcript_content}\n</transcript>")

    transcripts_str = "\n\n".join(transcript_strs)

    # Gather metadata; name and description are folded into the dumped metadata
    metadata_obj = to_jsonable_python(self.metadata)
    if self.name is not None:
        metadata_obj["name"] = self.name
    if self.description is not None:
        metadata_obj["description"] = self.description

    # Infinite width keeps yaml.dump from wrapping long scalar values
    yaml_width = float("inf")
    transcripts_str = (
        f"Here is a complete agent run for analysis purposes only:\n{transcripts_str}\n\n"
    )
    metadata_str = f"Metadata about the complete agent run:\n<agent run metadata>\n{yaml.dump(metadata_obj, width=yaml_width)}\n</agent run metadata>"

    full_text = f"{transcripts_str}{metadata_str}"
    if token_limit == sys.maxsize:
        return [full_text]

    # Compute message length; if fits, return the full transcript and metadata
    transcript_str_tokens = get_token_count(transcripts_str)
    metadata_str_tokens = get_token_count(metadata_str)
    if transcript_str_tokens + metadata_str_tokens <= token_limit:
        return [full_text]

    # Otherwise, split up the transcript and metadata into chunks
    partial_header = "Here is a partial agent run for analysis purposes only:\n"
    # NOTE(review): the 50-token reduction presumably leaves headroom for the header
    # and <transcript> wrapper text added around each chunk — confirm.
    chunk_limit = token_limit - 50

    results: list[str] = []
    transcript_token_counts = [get_token_count(s) for s in transcript_strs]
    ranges = group_messages_into_ranges(
        transcript_token_counts, metadata_str_tokens, chunk_limit
    )
    for msg_range in ranges:
        if msg_range.include_metadata:
            cur_transcript_str = "\n\n".join(transcript_strs[msg_range.start : msg_range.end])
            # Separate transcripts from metadata with a blank line, matching the
            # full-run layout above (previously the two were glued together).
            results.append(f"{partial_header}{cur_transcript_str}\n\n{metadata_str}")
            continue

        assert (
            msg_range.end == msg_range.start + 1
        ), "Ranges without metadata should be a single message"
        if msg_range.num_tokens < chunk_limit:
            # Reuse the string rendered above: it carries the correct transcript
            # index and is exactly what the token count was computed from.
            results.append(f"{partial_header}{transcript_strs[msg_range.start]}")
        else:
            # A single transcript exceeds the chunk limit: let the transcript split
            # itself, keeping the same transcript index as the first rendering.
            t = self.transcripts[msg_range.start]
            transcript_fragments: list[str] = t.to_str(
                token_limit=chunk_limit,
                transcript_idx=msg_range.start,
            )
            for fragment in transcript_fragments:
                results.append(f"{partial_header}<transcript>\n{fragment}\n</transcript>")
    return results
|
|
221
|
+
|
|
222
|
+
@property
def text(self) -> str:
    """Render the whole agent run as a single string.

    Returns:
        str: The first (and, with no token limit, only) string produced by
        ``_to_text_impl``: all transcripts followed by the run metadata.
    """
    rendered = self._to_text_impl(token_limit=sys.maxsize)
    return rendered[0]
|
|
230
|
+
|
|
231
|
+
##############################
|
|
232
|
+
# New text rendering methods #
|
|
233
|
+
##############################
|
|
234
|
+
|
|
235
|
+
# Transcript ID -> Transcript. None until the transcript_dict property builds it;
# reset to None by _invalidate_caches.
_transcript_dict: dict[str, Transcript] | None = PrivateAttr(default=None)
# Transcript Group ID -> Transcript Group. None until the transcript_group_dict
# property builds it; reset to None by _invalidate_caches.
_transcript_group_dict: dict[str, TranscriptGroup] | None = PrivateAttr(default=None)
# Canonical tree cache keyed by full_tree flag. Each value maps a parent id to its
# sorted list of (type, id) children, where type is "t" or "tg".
_canonical_tree_cache: dict[bool, dict[str | None, list[tuple[Literal["t", "tg"], str]]]] = (
    PrivateAttr(default_factory=dict)
)
# Transcript IDs (depth-first) cache keyed by full_tree flag. Filled together with
# _canonical_tree_cache from a single tree build.
_transcript_ids_ordered_cache: dict[bool, list[str]] = PrivateAttr(default_factory=dict)
|
|
245
|
+
|
|
246
|
+
@property
def transcript_dict(self) -> dict[str, Transcript]:
    """Return the transcript-ID -> Transcript mapping, building it on first use."""
    lookup = self._transcript_dict
    if lookup is None:
        lookup = {transcript.id: transcript for transcript in self.transcripts}
        self._transcript_dict = lookup
    return lookup
|
|
252
|
+
|
|
253
|
+
@property
def transcript_group_dict(self) -> dict[str, TranscriptGroup]:
    """Return the group-ID -> TranscriptGroup mapping, building it on first use."""
    lookup = self._transcript_group_dict
    if lookup is None:
        lookup = {group.id: group for group in self.transcript_groups}
        self._transcript_group_dict = lookup
    return lookup
|
|
259
|
+
|
|
260
|
+
def _invalidate_caches(self) -> None:
    """Drop every derived lookup so it is rebuilt after transcripts or groups change."""
    # The per-flag caches are emptied in place rather than replaced.
    for per_flag_cache in (self._canonical_tree_cache, self._transcript_ids_ordered_cache):
        per_flag_cache.clear()
    # The ID lookups are rebuilt lazily by their properties once set back to None.
    self._transcript_group_dict = None
    self._transcript_dict = None
|
|
266
|
+
|
|
267
|
+
def get_canonical_tree(
    self, full_tree: bool = False
) -> dict[str | None, list[tuple[Literal["t", "tg"], str]]]:
    """Return the canonical, sorted transcript group tree, computing it if needed.

    Args:
        full_tree: If True, include all transcript groups regardless of whether
            they contain transcripts. If False, include only the minimal tree
            that connects relevant groups and transcripts.

    Returns:
        Canonical tree mapping parent group id (or "__global_root") to a list of
        children (type, id) tuples sorted by creation time.
    """
    # Both caches are filled from one build, so rebuild if either entry is missing.
    already_built = (
        full_tree in self._canonical_tree_cache
        and full_tree in self._transcript_ids_ordered_cache
    )
    if not already_built:
        tree, idx_by_transcript = self._build_canonical_tree(full_tree=full_tree)
        self._canonical_tree_cache[full_tree] = tree
        self._transcript_ids_ordered_cache[full_tree] = list(idx_by_transcript)
    return self._canonical_tree_cache[full_tree]
|
|
289
|
+
|
|
290
|
+
def get_transcript_ids_ordered(self, full_tree: bool = False) -> list[str]:
    """Return the depth-first transcript id ordering, computing it if needed.

    Args:
        full_tree: Whether to compute based on the full tree or the minimal tree.

    Returns:
        List of transcript ids in depth-first order.
    """
    # Both caches are filled from one build, so rebuild if either entry is missing.
    already_built = (
        full_tree in self._transcript_ids_ordered_cache
        and full_tree in self._canonical_tree_cache
    )
    if not already_built:
        tree, idx_by_transcript = self._build_canonical_tree(full_tree=full_tree)
        self._canonical_tree_cache[full_tree] = tree
        self._transcript_ids_ordered_cache[full_tree] = list(idx_by_transcript)
    return self._transcript_ids_ordered_cache[full_tree]
|
|
307
|
+
|
|
308
|
+
def _build_canonical_tree(
    self, full_tree: bool = False
) -> tuple[dict[str | None, list[tuple[Literal["t", "tg"], str]]], dict[str, int]]:
    """Build the sorted parent -> children tree and the depth-first transcript indices.

    Args:
        full_tree: If True, every transcript group is placed in the tree. If False,
            only groups that directly hold a transcript, and their ancestors, are.

    Returns:
        A ``(tree, transcript_idx_map)`` pair. ``tree`` maps a parent group id (or
        ``"__global_root"``) to its children as ``(type, id)`` tuples, where type is
        ``"t"`` for a transcript and ``"tg"`` for a transcript group. Children are
        sorted by ``created_at`` (missing timestamps sort last); ties are broken by
        type and then by position in ``self.transcripts`` / ``self.transcript_groups``
        so the order is the same in every process. ``transcript_idx_map`` maps each
        transcript id reachable from the root to its depth-first index; its key
        order is the traversal order.
    """
    t_dict = self.transcript_dict
    tg_dict = self.transcript_group_dict

    # Find all transcript groups that have direct transcript children
    # Also keep track of transcripts that are not in a group
    tgs_to_transcripts: dict[str, set[str]] = {}
    for transcript in t_dict.values():
        if transcript.transcript_group_id is None:
            tgs_to_transcripts.setdefault("__global_root", set()).add(transcript.id)
        else:
            tgs_to_transcripts.setdefault(transcript.transcript_group_id, set()).add(
                transcript.id
            )

    # tg_tree maps from parent -> children. A child can be a group or a transcript.
    # A parent must be a group (or "__global_root", for transcripts that are not in a group).
    tg_tree: dict[str, set[tuple[Literal["t", "tg"], str]]] = {}

    if full_tree:
        for tg_id, tg in tg_dict.items():
            tg_tree.setdefault(tg.parent_transcript_group_id or "__global_root", set()).add(
                ("tg", tg_id)
            )
            for t_id in tgs_to_transcripts.get(tg_id, []):
                tg_tree.setdefault(tg_id, set()).add(("t", t_id))
        for t_id, t in t_dict.items():
            tg_tree.setdefault(t.transcript_group_id or "__global_root", set()).add(("t", t_id))
    else:
        # Initialize q with "important" tgs
        q, seen = deque(tgs_to_transcripts.keys()), set(tgs_to_transcripts.keys())

        # Do an "upwards BFS" from leaves up to the root. Builds a tree of only relevant nodes.
        while q:
            u_id = q.popleft()
            u = tg_dict.get(u_id)  # None if __global_root (or an id with no matching group)

            # Add the transcripts under this tg
            for t_id in tgs_to_transcripts.get(u_id, []):
                tg_tree.setdefault(u_id, set()).add(("t", t_id))

            # Add an edge from the parent
            if u is not None:
                par_id = u.parent_transcript_group_id or "__global_root"
                # Mark u as a child of par
                tg_tree.setdefault(par_id, set()).add(("tg", u_id))
                # If we haven't investigated the parent before, add to q
                if par_id not in seen:
                    q.append(par_id)
                    seen.add(par_id)

    # For each node, sort by created_at timestamp.
    # Children are collected in sets, whose iteration order follows string hashes and
    # so can change between interpreter runs. Sorting on created_at alone would leave
    # tied or missing timestamps in that arbitrary order, so ties fall back to the
    # object's type and its position in the run's own lists.
    t_position = {t_id: pos for pos, t_id in enumerate(t_dict)}
    tg_position = {tg_id: pos for pos, tg_id in enumerate(tg_dict)}

    def _cmp(element: tuple[Literal["t", "tg"], str]) -> tuple[datetime, str, int]:
        obj_type, obj_id = element
        if obj_type == "tg":
            return (tg_dict[obj_id].created_at or datetime.max, obj_type, tg_position[obj_id])
        return (t_dict[obj_id].created_at or datetime.max, obj_type, t_position[obj_id])

    c_tree: dict[str | None, list[tuple[Literal["t", "tg"], str]]] = {}
    for tg_id, children in tg_tree.items():
        c_tree[tg_id] = sorted(children, key=_cmp)

    # Compute transcript indices as the depth-first traversal index
    transcript_idx_map: dict[str, int] = {}

    def _assign_transcript_indices(cur_tg_id: str, next_idx: int) -> int:
        for child_type, child_id in c_tree.get(cur_tg_id, []):
            if child_type == "tg":
                next_idx = _assign_transcript_indices(child_id, next_idx)
            else:
                transcript_idx_map[child_id] = next_idx
                next_idx += 1
        return next_idx

    _assign_transcript_indices("__global_root", 0)

    return c_tree, transcript_idx_map
|
|
390
|
+
|
|
391
|
+
def delete_transcript_group_subtree(self, transcript_group_id: str) -> None:
    """Delete a transcript group and all descendant groups/transcripts using the canonical tree.

    Args:
        transcript_group_id: ID of the group at the root of the subtree to remove.

    Raises:
        ValueError: If the ID is the ``"__global_root"`` sentinel or does not match
            any transcript group on this run.
    """
    if transcript_group_id == "__global_root":
        raise ValueError("Cannot delete the global root sentinel")
    if transcript_group_id not in self.transcript_group_dict:
        raise ValueError(
            f"Transcript group '{transcript_group_id}' does not exist on this run."
        )

    # The full tree is used so that groups without transcripts are removed too.
    canonical_tree = self.get_canonical_tree(full_tree=True)
    groups_to_delete: set[str] = set()
    transcripts_to_delete: set[str] = set()

    queue: deque[str] = deque([transcript_group_id])
    while queue:
        current_group = queue.popleft()
        if current_group in groups_to_delete:
            # Already visited. Without this check, a group that is its own ancestor
            # (e.g. parent_transcript_group_id equal to its own id) would be
            # re-queued forever.
            continue
        groups_to_delete.add(current_group)
        for child_type, child_id in canonical_tree.get(current_group, []):
            if child_type == "tg":
                queue.append(child_id)
            else:
                transcripts_to_delete.add(child_id)

    # groups_to_delete always contains at least the requested group at this point.
    self.transcript_groups = [
        tg for tg in self.transcript_groups if tg.id not in groups_to_delete
    ]
    if transcripts_to_delete:
        self.transcripts = [t for t in self.transcripts if t.id not in transcripts_to_delete]

    self._invalidate_caches()
|
|
422
|
+
|
|
423
|
+
def to_text_new(
    self,
    agent_run_alias: int | str = 0,
    t_idx_map: dict[str, int] | None = None,
    indent: int = 0,
    full_tree: bool = False,
):
    """Render the agent run as tagged text by walking the canonical tree.

    Args:
        agent_run_alias: Label used in the ``<|agent run ...|>`` tags. An int ``n``
            is turned into ``"R{n}"``; a string is used as given.
        t_idx_map: Transcript id -> alias passed to each transcript's renderer.
            Defaults to each transcript's depth-first position in the tree.
        indent: Indent forwarded to the transcript and group renderers; when
            positive, the metadata text is indented by that many spaces.
        full_tree: Whether to walk the full tree or the minimal tree.

    Returns:
        The rendered text: the tree content, then the agent run metadata block
        (omitted when ``dump_metadata`` returns None), wrapped in agent run tags.
    """
    alias = f"R{agent_run_alias}" if isinstance(agent_run_alias, int) else agent_run_alias

    tree = self.get_canonical_tree(full_tree=full_tree)
    ordered_ids = self.get_transcript_ids_ordered(full_tree=full_tree)
    alias_by_transcript = (
        t_idx_map
        if t_idx_map is not None
        else {t_id: position for position, t_id in enumerate(ordered_ids)}
    )
    transcripts_by_id = self.transcript_dict
    groups_by_id = self.transcript_group_dict

    # Traverse the tree and render the string
    def _render(group_id: str) -> str:
        rendered_children = [
            _render(child_id)
            if kind == "tg"
            else transcripts_by_id[child_id].to_text_new(
                transcript_alias=alias_by_transcript[child_id],
                indent=indent,
            )
            for kind, child_id in tree.get(group_id, [])
        ]
        body = "\n".join(rendered_children)

        # Every real group delegates its wrapper to TranscriptGroup; the global
        # root gets no wrapper.
        if group_id != "__global_root":
            return groups_by_id[group_id].to_text_new(children_text=body, indent=indent)
        return body

    text = _render("__global_root")

    # Append agent run metadata below the full content
    metadata_text = dump_metadata(self.metadata)
    if metadata_text is not None:
        if indent > 0:
            metadata_text = textwrap.indent(metadata_text, " " * indent)
        metadata_alias = f"{alias}M"
        text = (
            f"{text}"
            f"\n<|agent run metadata {metadata_alias}|>\n{metadata_text}\n</|agent run metadata {metadata_alias}|>"
        )

    return f"<|agent run {alias}|>\n{text}\n</|agent run {alias}|>\n"
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from docent.data_models.chat.content import Content, ContentReasoning, ContentText
|
|
2
|
+
from docent.data_models.chat.message import (
|
|
3
|
+
AssistantMessage,
|
|
4
|
+
ChatMessage,
|
|
5
|
+
DocentAssistantMessage,
|
|
6
|
+
DocentChatMessage,
|
|
7
|
+
SystemMessage,
|
|
8
|
+
ToolMessage,
|
|
9
|
+
UserMessage,
|
|
10
|
+
parse_chat_message,
|
|
11
|
+
parse_docent_chat_message,
|
|
12
|
+
)
|
|
13
|
+
from docent.data_models.chat.tool import (
|
|
14
|
+
ToolCall,
|
|
15
|
+
ToolCallContent,
|
|
16
|
+
ToolInfo,
|
|
17
|
+
ToolParams,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"ChatMessage",
|
|
22
|
+
"DocentChatMessage",
|
|
23
|
+
"AssistantMessage",
|
|
24
|
+
"DocentAssistantMessage",
|
|
25
|
+
"SystemMessage",
|
|
26
|
+
"ToolMessage",
|
|
27
|
+
"UserMessage",
|
|
28
|
+
"Content",
|
|
29
|
+
"ContentReasoning",
|
|
30
|
+
"ContentText",
|
|
31
|
+
"ToolCall",
|
|
32
|
+
"ToolCallContent",
|
|
33
|
+
"ToolInfo",
|
|
34
|
+
"ToolParams",
|
|
35
|
+
"parse_chat_message",
|
|
36
|
+
"parse_docent_chat_message",
|
|
37
|
+
]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import Annotated, Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Discriminator
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseContent(BaseModel):
    """Base class for all content types in chat messages.

    Provides the foundation for different content types with a discriminator field.

    Attributes:
        type: The content type identifier, used for discriminating between content types.
    """

    # Every content kind the base accepts. In this module only "text" and
    # "reasoning" have concrete subclasses; NOTE(review): confirm whether
    # "image", "audio" and "video" are handled elsewhere.
    type: Literal["text", "reasoning", "image", "audio", "video"]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ContentText(BaseContent):
    """Text content for chat messages.

    Represents plain text content in a chat message.

    Attributes:
        type: Fixed as "text" to identify this content type.
        text: The actual text content.
        refusal: Optional flag indicating if this is a refusal message; defaults to None.
    """

    # Narrows the base class's Literal to one value; the ignore presumably silences
    # the type checker's complaint about overriding the field with a narrower type.
    type: Literal["text"] = "text"  # type: ignore
    text: str
    refusal: bool | None = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ContentReasoning(BaseContent):
    """Reasoning content for chat messages.

    Represents reasoning or thought process content in a chat message.

    Attributes:
        type: Fixed as "reasoning" to identify this content type.
        reasoning: The actual reasoning text.
        signature: Optional signature associated with the reasoning; defaults to None.
        redacted: Flag indicating if the reasoning has been redacted; defaults to False.
    """

    # Narrows the base class's Literal to one value; the ignore presumably silences
    # the type checker's complaint about overriding the field with a narrower type.
    type: Literal["reasoning"] = "reasoning"  # type: ignore
    reasoning: str
    signature: str | None = None
    redacted: bool = False
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Content type discriminated union: the "type" field selects which member model
# a value is validated as ("text" -> ContentText, "reasoning" -> ContentReasoning).
Content = Annotated[ContentText | ContentReasoning, Discriminator("type")]
"""Discriminated union of possible content types using the 'type' field.
Can be either ContentText or ContentReasoning.
"""
|