docent-python 0.1.3a0__py3-none-any.whl → 0.1.5a0__py3-none-any.whl
This diff shows the changes between two publicly available package versions as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release.
This version of docent-python might be problematic.
- docent/data_models/__init__.py +2 -9
- docent/data_models/agent_run.py +30 -20
- docent/data_models/metadata.py +229 -229
- docent/data_models/transcript.py +56 -16
- docent/loaders/load_inspect.py +37 -25
- docent/sdk/client.py +33 -23
- docent/trace.py +868 -304
- docent/trace_temp.py +1086 -0
- {docent_python-0.1.3a0.dist-info → docent_python-0.1.5a0.dist-info}/METADATA +1 -2
- {docent_python-0.1.3a0.dist-info → docent_python-0.1.5a0.dist-info}/RECORD +12 -12
- docent/trace_alt.py +0 -497
- {docent_python-0.1.3a0.dist-info → docent_python-0.1.5a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.3a0.dist-info → docent_python-0.1.5a0.dist-info}/licenses/LICENSE.md +0 -0
docent/data_models/transcript.py
CHANGED
@@ -11,7 +11,6 @@ from docent.data_models._tiktoken_util import (
     truncate_to_token_limit,
 )
 from docent.data_models.chat import AssistantMessage, ChatMessage, ContentReasoning
-from docent.data_models.metadata import BaseMetadata
 
 # Template for formatting individual transcript blocks
 TRANSCRIPT_BLOCK_TEMPLATE = """
@@ -63,6 +62,53 @@ def format_chat_message(
     )
 
 
+class TranscriptGroup(BaseModel):
+    """Represents a group of transcripts that are logically related.
+
+    A transcript group can contain multiple transcripts and can have a hierarchical
+    structure with parent groups. This is useful for organizing transcripts into
+    logical units like experiments, tasks, or sessions.
+
+    Attributes:
+        id: Unique identifier for the transcript group, auto-generated by default.
+        name: Optional human-readable name for the transcript group.
+        description: Optional description of the transcript group.
+        parent_transcript_group_id: Optional ID of the parent transcript group.
+        metadata: Additional structured metadata about the transcript group.
+    """
+
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    name: str | None = None
+    description: str | None = None
+    parent_transcript_group_id: str | None = None
+    metadata: dict[str, Any] = Field(default_factory=dict)
+
+    @field_serializer("metadata")
+    def serialize_metadata(self, metadata: dict[str, Any], _info: Any) -> dict[str, Any]:
+        """
+        Custom serializer for the metadata field so the internal fields are explicitly preserved.
+        """
+        return fake_model_dump(metadata)
+
+    @field_validator("metadata", mode="before")
+    @classmethod
+    def _validate_metadata_type(cls, v: Any) -> Any:
+        if v is not None and not isinstance(v, dict):
+            raise ValueError(f"metadata must be a dictionary, got {type(v).__name__}")
+        return v  # type: ignore
+
+
+def fake_model_dump(obj: dict[str, Any]) -> dict[str, Any]:
+    """
+    Emulate the action of pydantic.model_dump() for non-pydantic objects (to handle nested values)
+    """
+
+    class _FakeModel(BaseModel):
+        data: dict[str, Any]
+
+    return _FakeModel(data=obj).model_dump()["data"]
+
+
 class Transcript(BaseModel):
     """Represents a transcript of messages in a conversation with an AI agent.
 
@@ -74,6 +120,7 @@ class Transcript(BaseModel):
         id: Unique identifier for the transcript, auto-generated by default.
         name: Optional human-readable name for the transcript.
         description: Optional description of the transcript.
+        transcript_group_id: Optional ID of the transcript group this transcript belongs to.
         messages: List of chat messages in the transcript.
         metadata: Additional structured metadata about the transcript.
     """
@@ -81,27 +128,25 @@ class Transcript(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))
     name: str | None = None
     description: str | None = None
+    transcript_group_id: str | None = None
 
     messages: list[ChatMessage]
-    metadata:
-
+    metadata: dict[str, Any] = Field(default_factory=dict)
     _units_of_action: list[list[int]] | None = PrivateAttr(default=None)
 
     @field_serializer("metadata")
-    def serialize_metadata(self, metadata:
+    def serialize_metadata(self, metadata: dict[str, Any], _info: Any) -> dict[str, Any]:
         """
         Custom serializer for the metadata field so the internal fields are explicitly preserved.
         """
-        return metadata
+        return fake_model_dump(metadata)
 
     @field_validator("metadata", mode="before")
     @classmethod
     def _validate_metadata_type(cls, v: Any) -> Any:
-        if v is not None and not isinstance(v,
-            raise ValueError(
-
-            )
-        return v
+        if v is not None and not isinstance(v, dict):
+            raise ValueError(f"metadata must be a dict, got {type(v).__name__}")
+        return v  # type: ignore
 
     @property
     def units_of_action(self) -> list[list[int]]:
@@ -297,12 +342,7 @@ class Transcript(BaseModel):
         blocks_str = "\n".join(au_blocks)
 
         # Gather metadata
-        metadata_obj = self.metadata
-        # Add the field descriptions if they exist
-        metadata_obj = {
-            (f"{k} ({d})" if (d := self.metadata.get_field_description(k)) is not None else k): v
-            for k, v in metadata_obj.items()
-        }
+        metadata_obj = fake_model_dump(self.metadata)
 
         yaml_width = float("inf")
         block_str = f"<blocks>\n{blocks_str}\n</blocks>\n"
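Net effect for transcript.py: Transcript.metadata drops the typed BaseMetadata object in favor of a plain dict[str, Any], the new TranscriptGroup model uses the same dict-based metadata, and fake_model_dump keeps any pydantic values nested inside those dicts serializable. A minimal sketch of that trick, assuming a hypothetical nested Provenance model that is not part of docent:

from typing import Any

from pydantic import BaseModel


def fake_model_dump(obj: dict[str, Any]) -> dict[str, Any]:
    # Same trick as in transcript.py: wrap the dict in a throwaway model so
    # pydantic's model_dump() recurses into nested BaseModel values.
    class _FakeModel(BaseModel):
        data: dict[str, Any]

    return _FakeModel(data=obj).model_dump()["data"]


class Provenance(BaseModel):  # hypothetical metadata value, for illustration only
    source: str
    run_index: int


meta = {"provenance": Provenance(source="inspect", run_index=3), "score": 0.5}
print(fake_model_dump(meta))
# expected: {'provenance': {'source': 'inspect', 'run_index': 3}, 'score': 0.5}

The old serializer returned the dict as-is, which would have left nested model instances as live objects in the dumped output.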
docent/loaders/load_inspect.py
CHANGED
@@ -1,11 +1,13 @@
+from typing import Any
+
 from inspect_ai.log import EvalLog
 from inspect_ai.scorer import CORRECT, INCORRECT, NOANSWER, PARTIAL, Score
 
-from docent.data_models import AgentRun,
+from docent.data_models import AgentRun, Transcript
 from docent.data_models.chat import parse_chat_message
 
 
-def _normalize_inspect_score(score: Score) ->
+def _normalize_inspect_score(score: Score) -> Any:
     """
     Normalize an inspect score to a float. This implements the same logic as inspect_ai.scorer._metric.value_to_float, but fails more conspicuously.
 
@@ -16,30 +18,39 @@ def _normalize_inspect_score(score: Score) -> float | None:
     The normalized score as a float, or None if the score is not a valid value.
     """
 
-
-
-
-
-
-
-
-
-
-    value
+    def _leaf_normalize(value: int | float | bool | str | None) -> float | str | None:
+        if value is None:
+            return None
+        if isinstance(value, int | float | bool):
+            return float(value)
+        if value == CORRECT:
+            return 1.0
+        if value == PARTIAL:
+            return 0.5
+        if value in [INCORRECT, NOANSWER]:
+            return 0
+        value = str(value).lower()
         if value in ["yes", "true"]:
             return 1.0
-
+        if value in ["no", "false"]:
             return 0.0
-
+        if value.replace(".", "").isnumeric():
             return float(value)
+        return value
 
-
+    if isinstance(score.value, int | float | bool | str):
+        return _leaf_normalize(score.value)
+    if isinstance(score.value, list):
+        return [_leaf_normalize(v) for v in score.value]
+    assert isinstance(score.value, dict), "Inspect score must be leaf value, list, or dict"
+    return {k: _leaf_normalize(v) for k, v in score.value.items()}
 
 
 def load_inspect_log(log: EvalLog) -> list[AgentRun]:
     if log.samples is None:
         return []
 
+    # TODO(vincent): fix this
     agent_runs: list[AgentRun] = []
 
     for s in log.samples:
@@ -51,22 +62,23 @@ def load_inspect_log(log: EvalLog) -> list[AgentRun]:
         else:
             sample_scores = {k: _normalize_inspect_score(v) for k, v in s.scores.items()}
 
-        metadata =
-            task_id
-            sample_id
-            epoch_id
-            model
-            additional_metadata
-            scores
+        metadata = {
+            "task_id": log.eval.task,
+            "sample_id": str(sample_id),
+            "epoch_id": epoch_id,
+            "model": log.eval.model,
+            "additional_metadata": s.metadata,
+            "scores": sample_scores,
             # Scores could have answers, explanations, and other metadata besides the values we extract
-            scoring_metadata
-
+            "scoring_metadata": s.scores,
+        }
 
         agent_runs.append(
             AgentRun(
                 transcripts={
                     "main": Transcript(
-                        messages=[parse_chat_message(m.model_dump()) for m in s.messages]
+                        messages=[parse_chat_message(m.model_dump()) for m in s.messages],
+                        metadata={},
                     )
                 },
                 metadata=metadata,
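The rewritten normalizer works per leaf and then maps over containers: numbers and booleans become floats, Inspect's CORRECT/PARTIAL grades map to 1.0 and 0.5, INCORRECT/NOANSWER to 0, "yes"/"no" strings to 1.0/0.0, numeric strings are parsed, and any other string passes through; list- and dict-valued scores are normalized element by element. A rough usage sketch, calling the private helper directly purely for illustration:

from inspect_ai.scorer import CORRECT, PARTIAL, Score

from docent.loaders.load_inspect import _normalize_inspect_score

# Leaf values: letter grades, booleans, and numeric strings all normalize to floats.
print(_normalize_inspect_score(Score(value=CORRECT)))   # 1.0
print(_normalize_inspect_score(Score(value=PARTIAL)))   # 0.5
print(_normalize_inspect_score(Score(value=True)))      # 1.0
print(_normalize_inspect_score(Score(value="0.75")))    # 0.75

# Dict-valued scores are normalized per key; unrecognized strings pass through unchanged.
print(_normalize_inspect_score(Score(value={"accuracy": CORRECT, "label": "refusal"})))
# expected: {'accuracy': 1.0, 'label': 'refusal'}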
docent/sdk/client.py
CHANGED
@@ -197,75 +197,85 @@ class Docent:
         return response.json()
 
     def list_searches(self, collection_id: str) -> list[dict[str, Any]]:
-        """List all
+        """List all rubrics for a given collection.
 
         Args:
             collection_id: ID of the Collection.
 
         Returns:
-            list: List of dictionaries containing
+            list: List of dictionaries containing rubric information.
 
         Raises:
             requests.exceptions.HTTPError: If the API request fails.
         """
-        url = f"{self._server_url}/{collection_id}/
+        url = f"{self._server_url}/rubric/{collection_id}/rubrics"
         response = self._session.get(url)
         response.raise_for_status()
         return response.json()
 
-    def get_search_results(
-
-
+    def get_search_results(
+        self, collection_id: str, rubric_id: str, rubric_version: int
+    ) -> list[dict[str, Any]]:
+        """Get rubric results for a given collection, rubric and version.
 
         Args:
             collection_id: ID of the Collection.
-
+            rubric_id: The ID of the rubric to get results for.
+            rubric_version: The version of the rubric to get results for.
 
         Returns:
-            list: List of dictionaries containing
+            list: List of dictionaries containing rubric result information.
 
         Raises:
             requests.exceptions.HTTPError: If the API request fails.
         """
-        url = f"{self._server_url}/{collection_id}/
-        response = self._session.
+        url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/results"
+        response = self._session.get(url, params={"rubric_version": rubric_version})
         response.raise_for_status()
         return response.json()
 
-    def list_search_clusters(
-
-
+    def list_search_clusters(
+        self, collection_id: str, rubric_id: str, rubric_version: int | None = None
+    ) -> list[dict[str, Any]]:
+        """List all centroids for a given collection and rubric.
 
         Args:
             collection_id: ID of the Collection.
-
+            rubric_id: The ID of the rubric to get centroids for.
+            rubric_version: Optional version of the rubric. If not provided, uses latest.
 
         Returns:
-            list: List of dictionaries containing
+            list: List of dictionaries containing centroid information.
 
         Raises:
             requests.exceptions.HTTPError: If the API request fails.
         """
-        url = f"{self._server_url}/{collection_id}/
-
+        url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/centroids"
+        params: dict[str, int] = {}
+        if rubric_version is not None:
+            params["rubric_version"] = rubric_version
+        response = self._session.get(url, params=params)
         response.raise_for_status()
         return response.json()
 
-    def get_cluster_matches(
-
+    def get_cluster_matches(
+        self, collection_id: str, rubric_id: str, rubric_version: int
+    ) -> list[dict[str, Any]]:
+        """Get centroid assignments for a given rubric.
 
         Args:
             collection_id: ID of the Collection.
-
+            rubric_id: The ID of the rubric to get assignments for.
+            rubric_version: The version of the rubric to get assignments for.
 
         Returns:
-            list: List of dictionaries containing
+            list: List of dictionaries containing centroid assignment information.
 
         Raises:
             requests.exceptions.HTTPError: If the API request fails.
         """
-        url = f"{self._server_url}/{collection_id}/
-        response = self._session.
+        url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/assignments"
+        response = self._session.get(url, params={"rubric_version": rubric_version})
         response.raise_for_status()
         return response.json()
 
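Taken together, the four SDK methods now hit the rubric routes and take explicit rubric_id / rubric_version arguments. A hedged usage sketch, assuming an already-configured Docent client and an existing rubric; the collection ID and the "id" key on the listed rubrics are placeholders, not confirmed field names:

from docent.sdk.client import Docent

client: Docent = ...  # construction (server URL, credentials) elided

collection_id = "col_123"  # placeholder

rubrics = client.list_searches(collection_id)   # GET /rubric/{collection_id}/rubrics
rubric_id = rubrics[0]["id"]                     # "id" key assumed for illustration

results = client.get_search_results(collection_id, rubric_id, rubric_version=1)
clusters = client.list_search_clusters(collection_id, rubric_id)  # version omitted -> latest
matches = client.get_cluster_matches(collection_id, rubric_id, rubric_version=1)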