docent-python 0.1.3a0__tar.gz → 0.1.5a0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/.gitignore +3 -0
  2. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/PKG-INFO +1 -2
  3. docent_python-0.1.5a0/docent/data_models/__init__.py +12 -0
  4. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/agent_run.py +30 -20
  5. docent_python-0.1.5a0/docent/data_models/metadata.py +229 -0
  6. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/transcript.py +56 -16
  7. docent_python-0.1.5a0/docent/loaders/load_inspect.py +88 -0
  8. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/sdk/client.py +33 -23
  9. docent_python-0.1.5a0/docent/trace.py +1650 -0
  10. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/pyproject.toml +1 -2
  11. docent_python-0.1.5a0/uv.lock +954 -0
  12. docent_python-0.1.3a0/docent/data_models/__init__.py +0 -19
  13. docent_python-0.1.3a0/docent/data_models/metadata.py +0 -229
  14. docent_python-0.1.3a0/docent/loaders/load_inspect.py +0 -76
  15. docent_python-0.1.3a0/docent/trace_alt.py +0 -497
  16. docent_python-0.1.3a0/uv.lock +0 -2030
  17. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/LICENSE.md +0 -0
  18. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/README.md +0 -0
  19. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/__init__.py +0 -0
  20. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/_log_util/__init__.py +0 -0
  21. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/_log_util/logger.py +0 -0
  22. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/_tiktoken_util.py +0 -0
  23. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/chat/__init__.py +0 -0
  24. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/chat/content.py +0 -0
  25. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/chat/message.py +0 -0
  26. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/chat/tool.py +0 -0
  27. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/citation.py +0 -0
  28. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/regex.py +0 -0
  29. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/shared_types.py +0 -0
  30. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/py.typed +0 -0
  31. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/samples/__init__.py +0 -0
  32. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/samples/load.py +0 -0
  33. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/samples/log.eval +0 -0
  34. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/samples/tb_airline.json +0 -0
  35. {docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/sdk/__init__.py +0 -0
  36. /docent_python-0.1.3a0/docent/trace.py → /docent_python-0.1.5a0/docent/trace_temp.py +0 -0
{docent_python-0.1.3a0 → docent_python-0.1.5a0}/.gitignore
@@ -192,3 +192,6 @@ personal/caden/*
 inspect_evals
 
 *.swp
+
+# test data cache
+data/cache

{docent_python-0.1.3a0 → docent_python-0.1.5a0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: docent-python
-Version: 0.1.3a0
+Version: 0.1.5a0
 Summary: Docent SDK
 Project-URL: Homepage, https://github.com/TransluceAI/docent
 Project-URL: Issues, https://github.com/TransluceAI/docent/issues
@@ -22,4 +22,3 @@ Requires-Dist: pydantic>=2.11.7
 Requires-Dist: pyyaml>=6.0.2
 Requires-Dist: tiktoken>=0.7.0
 Requires-Dist: tqdm>=4.67.1
-Requires-Dist: traceloop-sdk>=0.44.1

docent_python-0.1.5a0/docent/data_models/__init__.py
@@ -0,0 +1,12 @@
+from docent.data_models.agent_run import AgentRun
+from docent.data_models.citation import Citation
+from docent.data_models.regex import RegexSnippet
+from docent.data_models.transcript import Transcript, TranscriptGroup
+
+__all__ = [
+    "AgentRun",
+    "Citation",
+    "RegexSnippet",
+    "Transcript",
+    "TranscriptGroup",
+]
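
With the new package-level exports, the core data models are importable from `docent.data_models` directly. A minimal sketch of the new import surface:

```python
# The five data models re-exported by docent/data_models/__init__.py in 0.1.5a0.
from docent.data_models import (
    AgentRun,
    Citation,
    RegexSnippet,
    Transcript,
    TranscriptGroup,
)

print([cls.__name__ for cls in (AgentRun, Citation, RegexSnippet, Transcript, TranscriptGroup)])
```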

{docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/agent_run.py
@@ -1,3 +1,4 @@
+import json
 import sys
 from typing import Any, Literal, TypedDict, cast
 from uuid import uuid4
@@ -12,8 +13,11 @@ from pydantic import (
 )
 
 from docent.data_models._tiktoken_util import get_token_count, group_messages_into_ranges
-from docent.data_models.metadata import BaseAgentRunMetadata
-from docent.data_models.transcript import Transcript, TranscriptWithoutMetadataValidator
+from docent.data_models.transcript import (
+    Transcript,
+    TranscriptWithoutMetadataValidator,
+    fake_model_dump,
+)
 
 
 class FilterableField(TypedDict):
@@ -32,7 +36,7 @@ class AgentRun(BaseModel):
         name: Optional human-readable name for the agent run.
         description: Optional description of the agent run.
         transcripts: Dict mapping transcript IDs to Transcript objects.
-        metadata: Additional structured metadata about the agent run.
+        metadata: Additional structured metadata about the agent run as a JSON-serializable dictionary.
     """
 
     id: str = Field(default_factory=lambda: str(uuid4()))
@@ -40,23 +44,34 @@ class AgentRun(BaseModel):
     description: str | None = None
 
     transcripts: dict[str, Transcript]
-    metadata: BaseAgentRunMetadata
+    metadata: dict[str, Any] = Field(default_factory=dict)
 
     @field_serializer("metadata")
-    def serialize_metadata(self, metadata: BaseAgentRunMetadata, _info: Any) -> dict[str, Any]:
+    def serialize_metadata(self, metadata: dict[str, Any], _info: Any) -> dict[str, Any]:
         """
-        Custom serializer for the metadata field so the internal fields are explicitly preserved.
+        Custom serializer for the metadata field - returns the dict as-is since it's already serializable.
         """
-        return metadata.model_dump(strip_internal_fields=False)
+        return fake_model_dump(metadata)
 
     @field_validator("metadata", mode="before")
     @classmethod
-    def _validate_metadata_type(cls, v: Any) -> Any:
-        if v is not None and not isinstance(v, BaseAgentRunMetadata):
-            raise ValueError(
-                f"metadata must be an instance of BaseAgentRunMetadata, got {type(v).__name__}"
-            )
-        return v
+    def _validate_metadata_json_serializable(cls, v: Any) -> dict[str, Any]:
+        """
+        Validates that metadata is a dictionary and is JSON-serializable.
+        """
+        if v is None:
+            return {}
+
+        if not isinstance(v, dict):
+            raise ValueError(f"metadata must be a dictionary, got {type(v).__name__}")
+
+        # Check that the metadata is JSON serializable
+        try:
+            json.dumps(fake_model_dump(cast(dict[str, Any], v)))
+        except (TypeError, ValueError) as e:
+            raise ValueError(f"metadata must be JSON-serializable: {e}")
+
+        return cast(dict[str, Any], v)
 
     @model_validator(mode="after")
     def _validate_transcripts_not_empty(self):
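
The practical effect of this change: metadata is now any JSON-serializable dict rather than a `BaseAgentRunMetadata` subclass, and bad values fail at construction time. A sketch of the new contract; the message payload passed to `parse_chat_message` is a hypothetical minimal example, not a documented schema:

```python
from docent.data_models import AgentRun, Transcript
from docent.data_models.chat import parse_chat_message

# Plain-dict metadata; nested values are fine as long as they serialize to JSON.
run = AgentRun(
    transcripts={
        "main": Transcript(messages=[parse_chat_message({"role": "user", "content": "hi"})])
    },
    metadata={"model": "gpt-4o", "scores": {"accuracy": 1.0}},
)

# Non-serializable values now fail during validation. Pydantic surfaces the
# validator's ValueError as a ValidationError (itself a ValueError subclass).
try:
    AgentRun(transcripts=run.transcripts, metadata={"handle": object()})
except ValueError as e:
    print("rejected:", e)
```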
@@ -88,16 +103,11 @@ class AgentRun(BaseModel):
         transcripts_str = "\n\n".join(transcript_strs)
 
         # Gather metadata
-        metadata_obj = self.metadata.model_dump(strip_internal_fields=True)
+        metadata_obj = fake_model_dump(self.metadata)
         if self.name is not None:
             metadata_obj["name"] = self.name
         if self.description is not None:
             metadata_obj["description"] = self.description
-        # Add the field descriptions if they exist
-        metadata_obj = {
-            (f"{k} ({d})" if (d := self.metadata.get_field_description(k)) is not None else k): v
-            for k, v in metadata_obj.items()
-        }
 
         yaml_width = float("inf")
         transcripts_str = (
@@ -202,7 +212,7 @@ class AgentRun(BaseModel):
                 _explore_dict(cast(dict[str, Any], v), f"{prefix}.{k}", depth + 1)
 
         # Look at the agent run metadata
-        _explore_dict(self.metadata.model_dump(strip_internal_fields=True), "metadata", 0)
+        _explore_dict(fake_model_dump(self.metadata), "metadata", 0)
         # Look at the transcript metadata
         # TODO(mengk): restore this later when we have the ability to integrate with SQL.
         # for t_id, t in self.transcripts.items():

docent_python-0.1.5a0/docent/data_models/metadata.py
@@ -0,0 +1,229 @@
+# import traceback
+# from typing import Any, Optional
+
+# from pydantic import (
+#     BaseModel,
+#     ConfigDict,
+#     Field,
+#     PrivateAttr,
+#     SerializerFunctionWrapHandler,
+#     model_serializer,
+#     model_validator,
+# )
+
+# from docent._log_util import get_logger
+
+# logger = get_logger(__name__)
+
+# SINGLETONS = (int, float, str, bool)
+
+
+# class BaseMetadata(BaseModel):
+#     """Provides common functionality for accessing and validating metadata fields.
+#     All metadata classes should inherit from this class.
+
+#     Serialization Behavior:
+#         - Field descriptions are highly recommended and stored in serialized versions of the object.
+#         - When a subclass of BaseMetadata is uploaded to a server, all extra fields and their descriptions are retained.
+#         - To recover the original structure with proper typing upon download, use:
+#           `CustomMetadataClass.model_validate(obj.model_dump())`.
+
+#     Attributes:
+#         model_config: Pydantic configuration that allows extra fields.
+#         allow_fields_without_descriptions: Boolean indicating whether to allow fields without descriptions.
+#     """
+
+#     model_config = ConfigDict(extra="allow")
+#     allow_fields_without_descriptions: bool = True
+
+#     # Private attribute to store field descriptions
+#     _field_descriptions: dict[str, str | None] | None = PrivateAttr(default=None)
+#     _internal_basemetadata_fields: set[str] = PrivateAttr(
+#         default={
+#             "allow_fields_without_descriptions",
+#             "model_config",
+#             "_field_descriptions",
+#         }
+#     )
+
+#     @model_validator(mode="after")
+#     def _validate_field_types_and_descriptions(self):
+#         """Validates that all fields have descriptions and proper types.
+
+#         Returns:
+#             Self: The validated model instance.
+
+#         Raises:
+#             ValueError: If any field is missing a description or has an invalid type.
+#         """
+#         # Validate each field in the model
+#         for field_name, field_info in self.__class__.model_fields.items():
+#             if field_name in self._internal_basemetadata_fields:
+#                 continue
+
+#             # Check that field has a description
+#             if field_info.description is None:
+#                 if not self.allow_fields_without_descriptions:
+#                     raise ValueError(
+#                         f"Field `{field_name}` needs a description in the definition of `{self.__class__.__name__}`, like `{field_name}: T = Field(description=..., default=...)`. "
+#                         "To allow un-described fields, set `allow_fields_without_descriptions = True` on the instance or in your metadata class definition."
+#                     )
+
+#         # Validate that the metadata is JSON serializable
+#         try:
+#             self.model_dump_json()
+#         except Exception as e:
+#             raise ValueError(
+#                 f"Metadata is not JSON serializable: {e}. Traceback: {traceback.format_exc()}"
+#             )
+
+#         return self
+
+#     def model_post_init(self, __context: Any) -> None:
+#         """Initializes field descriptions from extra data after model initialization.
+
+#         Args:
+#             __context: The context provided by Pydantic's post-initialization hook.
+#         """
+#         fd = self.model_extra.pop("_field_descriptions", None) if self.model_extra else None
+#         if fd is not None:
+#             self._field_descriptions = fd
+
+#     @model_serializer(mode="wrap")
+#     def _serialize_model(self, handler: SerializerFunctionWrapHandler):
+#         # Call the default serializer
+#         data = handler(self)
+
+#         # Dump the field descriptions
+#         if self._field_descriptions is None:
+#             self._field_descriptions = self._compute_field_descriptions()
+#         data["_field_descriptions"] = self._field_descriptions
+
+#         return data
+
+#     def model_dump(
+#         self, *args: Any, strip_internal_fields: bool = False, **kwargs: Any
+#     ) -> dict[str, Any]:
+#         data = super().model_dump(*args, **kwargs)
+
+#         # Remove internal fields if requested
+#         if strip_internal_fields:
+#             for field in self._internal_basemetadata_fields:
+#                 if field in data:
+#                     data.pop(field)
+
+#         return data
+
+#     def get(self, key: str, default_value: Any = None) -> Any:
+#         """Gets a value from the metadata by key.
+
+#         Args:
+#             key: The key to look up in the metadata.
+#             default_value: Value to return if the key is not found. Defaults to None.
+
+#         Returns:
+#             Any: The value associated with the key, or the default value if not found.
+#         """
+#         # Check if the field exists in the model's fields
+#         if key in self.__class__.model_fields or (
+#             self.model_extra is not None and key in self.model_extra
+#         ):
+#             # Field exists, return its value (even if None)
+#             return getattr(self, key)
+
+#         logger.warning(f"Field '{key}' not found in {self.__class__.__name__}")
+#         return default_value
+
+#     def get_field_description(self, field_name: str) -> str | None:
+#         """Gets the description of a field defined in the model schema.
+
+#         Args:
+#             field_name: The name of the field.
+
+#         Returns:
+#             str or None: The description string if the field is defined in the model schema
+#             and has a description, otherwise None.
+#         """
+#         if self._field_descriptions is None:
+#             self._field_descriptions = self._compute_field_descriptions()
+
+#         if field_name in self._field_descriptions:
+#             return self._field_descriptions[field_name]
+
+#         logger.warning(
+#             f"Field description for '{field_name}' not found in {self.__class__.__name__}"
+#         )
+#         return None
+
+#     def get_all_field_descriptions(self) -> dict[str, str | None]:
+#         """Gets descriptions for all fields defined in the model schema.
+
+#         Returns:
+#             dict: A dictionary mapping field names to their descriptions.
+#                 Only includes fields that have descriptions defined in the schema.
+#         """
+#         if self._field_descriptions is None:
+#             self._field_descriptions = self._compute_field_descriptions()
+#         return self._field_descriptions
+
+#     def _compute_field_descriptions(self) -> dict[str, str | None]:
+#         """Computes descriptions for all fields in the model.
+
+#         Returns:
+#             dict: A dictionary mapping field names to their descriptions.
+#         """
+#         field_descriptions: dict[str, Optional[str]] = {}
+#         for field_name, field_info in self.__class__.model_fields.items():
+#             if field_name not in self._internal_basemetadata_fields:
+#                 field_descriptions[field_name] = field_info.description
+#         return field_descriptions
+
+
+# class BaseAgentRunMetadata(BaseMetadata):
+#     """Extends BaseMetadata with fields specific to agent evaluation runs.
+
+#     Attributes:
+#         scores: Dictionary of evaluation metrics.
+#     """
+
+#     scores: dict[str, int | float | bool | None] = Field(
+#         description="A dict of score_key -> score_value. Use one key for each metric you're tracking."
+#     )
+
+
+# class InspectAgentRunMetadata(BaseAgentRunMetadata):
+#     """Extends BaseAgentRunMetadata with fields specific to Inspect runs.
+
+#     Attributes:
+#         task_id: The ID of the 'benchmark' or 'set of evals' that the transcript belongs to
+#         sample_id: The specific task inside of the `task_id` benchmark that the transcript was run on
+#         epoch_id: Each `sample_id` should be run multiple times due to stochasticity; `epoch_id` is the integer index of a specific run.
+#         model: The model that was used to generate the transcript
+#         scoring_metadata: Additional metadata about the scoring process
+#         additional_metadata: Additional metadata about the transcript
+#     """
+
+#     task_id: str = Field(
+#         description="The ID of the 'benchmark' or 'set of evals' that the transcript belongs to"
+#     )
+
+#     # Identification of this particular run
+#     sample_id: str = Field(
+#         description="The specific task inside of the `task_id` benchmark that the transcript was run on"
+#     )
+#     epoch_id: int = Field(
+#         description="Each `sample_id` should be run multiple times due to stochasticity; `epoch_id` is the integer index of a specific run."
+#     )
+
+#     # Parameters for the run
+#     model: str = Field(description="The model that was used to generate the transcript")
+
+#     # Scoring
+#     scoring_metadata: dict[str, Any] | None = Field(
+#         description="Additional metadata about the scoring process"
+#     )
+
+#     # Inspect metadata
+#     additional_metadata: dict[str, Any] | None = Field(
+#         description="Additional metadata about the transcript"
+#     )
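
Since the entire `BaseMetadata` hierarchy now ships commented out, callers that previously subclassed `BaseAgentRunMetadata` or `InspectAgentRunMetadata` carry the same fields in the plain metadata dict. A hypothetical migration sketch using the field names from the old classes:

```python
# 0.1.3a0 (old): metadata was a typed pydantic model, e.g.
#   InspectAgentRunMetadata(task_id=..., sample_id=..., epoch_id=..., model=..., scores=...)
# 0.1.5a0 (new): the same information as a JSON-serializable dict.
metadata = {
    "task_id": "my_benchmark",  # hypothetical values throughout
    "sample_id": "sample-1",
    "epoch_id": 0,
    "model": "gpt-4o",
    "scores": {"accuracy": 1.0},
    "scoring_metadata": None,
    "additional_metadata": None,
}
```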

{docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/data_models/transcript.py
@@ -11,7 +11,6 @@ from docent.data_models._tiktoken_util import (
     truncate_to_token_limit,
 )
 from docent.data_models.chat import AssistantMessage, ChatMessage, ContentReasoning
-from docent.data_models.metadata import BaseMetadata
 
 # Template for formatting individual transcript blocks
 TRANSCRIPT_BLOCK_TEMPLATE = """
@@ -63,6 +62,53 @@ def format_chat_message(
     )
 
 
+class TranscriptGroup(BaseModel):
+    """Represents a group of transcripts that are logically related.
+
+    A transcript group can contain multiple transcripts and can have a hierarchical
+    structure with parent groups. This is useful for organizing transcripts into
+    logical units like experiments, tasks, or sessions.
+
+    Attributes:
+        id: Unique identifier for the transcript group, auto-generated by default.
+        name: Optional human-readable name for the transcript group.
+        description: Optional description of the transcript group.
+        parent_transcript_group_id: Optional ID of the parent transcript group.
+        metadata: Additional structured metadata about the transcript group.
+    """
+
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    name: str | None = None
+    description: str | None = None
+    parent_transcript_group_id: str | None = None
+    metadata: dict[str, Any] = Field(default_factory=dict)
+
+    @field_serializer("metadata")
+    def serialize_metadata(self, metadata: dict[str, Any], _info: Any) -> dict[str, Any]:
+        """
+        Custom serializer for the metadata field so the internal fields are explicitly preserved.
+        """
+        return fake_model_dump(metadata)
+
+    @field_validator("metadata", mode="before")
+    @classmethod
+    def _validate_metadata_type(cls, v: Any) -> Any:
+        if v is not None and not isinstance(v, dict):
+            raise ValueError(f"metadata must be a dictionary, got {type(v).__name__}")
+        return v  # type: ignore
+
+
+def fake_model_dump(obj: dict[str, Any]) -> dict[str, Any]:
+    """
+    Emulate the action of pydantic.model_dump() for non-pydantic objects (to handle nested values)
+    """
+
+    class _FakeModel(BaseModel):
+        data: dict[str, Any]
+
+    return _FakeModel(data=obj).model_dump()["data"]
+
+
 class Transcript(BaseModel):
     """Represents a transcript of messages in a conversation with an AI agent.
 
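`fake_model_dump` round-trips a plain dict through a throwaway pydantic model so that nested `BaseModel` values are dumped recursively, mirroring what `model_dump()` does for real model fields. A small sketch of the intended behavior; the `Inner` model is hypothetical:

```python
from pydantic import BaseModel

from docent.data_models.transcript import fake_model_dump


class Inner(BaseModel):  # hypothetical nested model
    x: int


# Nested pydantic values come back as plain dicts, ready for json.dumps.
dumped = fake_model_dump({"inner": Inner(x=1), "plain": [1, 2]})
print(dumped)  # expected: {'inner': {'x': 1}, 'plain': [1, 2]}
```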
@@ -74,6 +120,7 @@ class Transcript(BaseModel):
         id: Unique identifier for the transcript, auto-generated by default.
         name: Optional human-readable name for the transcript.
         description: Optional description of the transcript.
+        transcript_group_id: Optional ID of the transcript group this transcript belongs to.
         messages: List of chat messages in the transcript.
         metadata: Additional structured metadata about the transcript.
     """
@@ -81,27 +128,25 @@ class Transcript(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))
     name: str | None = None
     description: str | None = None
+    transcript_group_id: str | None = None
 
     messages: list[ChatMessage]
-    metadata: BaseMetadata = Field(default_factory=BaseMetadata)
-
+    metadata: dict[str, Any] = Field(default_factory=dict)
     _units_of_action: list[list[int]] | None = PrivateAttr(default=None)
 
     @field_serializer("metadata")
-    def serialize_metadata(self, metadata: BaseMetadata, _info: Any) -> dict[str, Any]:
+    def serialize_metadata(self, metadata: dict[str, Any], _info: Any) -> dict[str, Any]:
         """
         Custom serializer for the metadata field so the internal fields are explicitly preserved.
         """
-        return metadata.model_dump(strip_internal_fields=False)
+        return fake_model_dump(metadata)
 
     @field_validator("metadata", mode="before")
     @classmethod
     def _validate_metadata_type(cls, v: Any) -> Any:
-        if v is not None and not isinstance(v, BaseMetadata):
-            raise ValueError(
-                f"metadata must be an instance of BaseMetadata, got {type(v).__name__}"
-            )
-        return v
+        if v is not None and not isinstance(v, dict):
+            raise ValueError(f"metadata must be a dict, got {type(v).__name__}")
+        return v  # type: ignore
 
     @property
     def units_of_action(self) -> list[list[int]]:
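
Putting the two additions together, a transcript can now reference a `TranscriptGroup` through `transcript_group_id`. A minimal linkage sketch; the message payload is again a hypothetical example:

```python
from docent.data_models import Transcript, TranscriptGroup
from docent.data_models.chat import parse_chat_message

group = TranscriptGroup(name="session-1", metadata={"experiment": "demo"})

transcript = Transcript(
    transcript_group_id=group.id,  # ties the transcript to its group
    messages=[parse_chat_message({"role": "user", "content": "hello"})],
    metadata={"step": 1},
)
assert transcript.transcript_group_id == group.id
```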
@@ -297,12 +342,7 @@ class Transcript(BaseModel):
         blocks_str = "\n".join(au_blocks)
 
         # Gather metadata
-        metadata_obj = self.metadata.model_dump(strip_internal_fields=True)
-        # Add the field descriptions if they exist
-        metadata_obj = {
-            (f"{k} ({d})" if (d := self.metadata.get_field_description(k)) is not None else k): v
-            for k, v in metadata_obj.items()
-        }
+        metadata_obj = fake_model_dump(self.metadata)
 
         yaml_width = float("inf")
         block_str = f"<blocks>\n{blocks_str}\n</blocks>\n"

docent_python-0.1.5a0/docent/loaders/load_inspect.py
@@ -0,0 +1,88 @@
+from typing import Any
+
+from inspect_ai.log import EvalLog
+from inspect_ai.scorer import CORRECT, INCORRECT, NOANSWER, PARTIAL, Score
+
+from docent.data_models import AgentRun, Transcript
+from docent.data_models.chat import parse_chat_message
+
+
+def _normalize_inspect_score(score: Score) -> Any:
+    """
+    Normalize an inspect score to a float. This implements the same logic as inspect_ai.scorer._metric.value_to_float, but fails more conspicuously.
+
+    Args:
+        score: The inspect score to normalize.
+
+    Returns:
+        The normalized score as a float, or None if the score is not a valid value.
+    """
+
+    def _leaf_normalize(value: int | float | bool | str | None) -> float | str | None:
+        if value is None:
+            return None
+        if isinstance(value, int | float | bool):
+            return float(value)
+        if value == CORRECT:
+            return 1.0
+        if value == PARTIAL:
+            return 0.5
+        if value in [INCORRECT, NOANSWER]:
+            return 0
+        value = str(value).lower()
+        if value in ["yes", "true"]:
+            return 1.0
+        if value in ["no", "false"]:
+            return 0.0
+        if value.replace(".", "").isnumeric():
+            return float(value)
+        return value
+
+    if isinstance(score.value, int | float | bool | str):
+        return _leaf_normalize(score.value)
+    if isinstance(score.value, list):
+        return [_leaf_normalize(v) for v in score.value]
+    assert isinstance(score.value, dict), "Inspect score must be leaf value, list, or dict"
+    return {k: _leaf_normalize(v) for k, v in score.value.items()}
+
+
+def load_inspect_log(log: EvalLog) -> list[AgentRun]:
+    if log.samples is None:
+        return []
+
+    # TODO(vincent): fix this
+    agent_runs: list[AgentRun] = []
+
+    for s in log.samples:
+        sample_id = s.id
+        epoch_id = s.epoch
+
+        if s.scores is None:
+            sample_scores = {}
+        else:
+            sample_scores = {k: _normalize_inspect_score(v) for k, v in s.scores.items()}
+
+        metadata = {
+            "task_id": log.eval.task,
+            "sample_id": str(sample_id),
+            "epoch_id": epoch_id,
+            "model": log.eval.model,
+            "additional_metadata": s.metadata,
+            "scores": sample_scores,
+            # Scores could have answers, explanations, and other metadata besides the values we extract
+            "scoring_metadata": s.scores,
+        }
+
+        agent_runs.append(
+            AgentRun(
+                transcripts={
+                    "main": Transcript(
+                        messages=[parse_chat_message(m.model_dump()) for m in s.messages],
+                        metadata={},
+                    )
+                },
+                metadata=metadata,
+            )
+        )
+
+    return agent_runs
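
A usage sketch for the new loader, assuming an Inspect eval log on disk (the path is hypothetical); `read_eval_log` is the standard `inspect_ai` entry point for reading `.eval` files:

```python
from inspect_ai.log import read_eval_log

from docent.loaders.load_inspect import load_inspect_log

log = read_eval_log("logs/example.eval")  # hypothetical log path
runs = load_inspect_log(log)
for run in runs:
    # These keys are the ones load_inspect_log writes into run.metadata.
    print(run.metadata["task_id"], run.metadata["sample_id"], run.metadata["scores"])
```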

{docent_python-0.1.3a0 → docent_python-0.1.5a0}/docent/sdk/client.py
@@ -197,75 +197,85 @@ class Docent:
         return response.json()
 
     def list_searches(self, collection_id: str) -> list[dict[str, Any]]:
-        """List all searches for a given collection.
+        """List all rubrics for a given collection.
 
         Args:
             collection_id: ID of the Collection.
 
         Returns:
-            list: List of dictionaries containing search query information.
+            list: List of dictionaries containing rubric information.
 
         Raises:
             requests.exceptions.HTTPError: If the API request fails.
         """
-        url = f"{self._server_url}/{collection_id}/list_search_queries"
+        url = f"{self._server_url}/rubric/{collection_id}/rubrics"
         response = self._session.get(url)
         response.raise_for_status()
         return response.json()
 
-    def get_search_results(self, collection_id: str, search_query: str) -> list[dict[str, Any]]:
-        """Get search results for a given collection and search query.
-        Pass in either search_query or query_id.
+    def get_search_results(
+        self, collection_id: str, rubric_id: str, rubric_version: int
+    ) -> list[dict[str, Any]]:
+        """Get rubric results for a given collection, rubric and version.
 
         Args:
             collection_id: ID of the Collection.
-            search_query: The search query to get results for.
+            rubric_id: The ID of the rubric to get results for.
+            rubric_version: The version of the rubric to get results for.
 
         Returns:
-            list: List of dictionaries containing search result information.
+            list: List of dictionaries containing rubric result information.
 
         Raises:
             requests.exceptions.HTTPError: If the API request fails.
         """
-        url = f"{self._server_url}/{collection_id}/get_search_results"
-        response = self._session.post(url, json={"search_query": search_query})
+        url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/results"
+        response = self._session.get(url, params={"rubric_version": rubric_version})
        response.raise_for_status()
         return response.json()
 
-    def list_search_clusters(self, collection_id: str, search_query: str) -> list[dict[str, Any]]:
-        """List all search clusters for a given collection.
-        Pass in either search_query or query_id.
+    def list_search_clusters(
+        self, collection_id: str, rubric_id: str, rubric_version: int | None = None
+    ) -> list[dict[str, Any]]:
+        """List all centroids for a given collection and rubric.
 
         Args:
             collection_id: ID of the Collection.
-            search_query: The search query to get clusters for.
+            rubric_id: The ID of the rubric to get centroids for.
+            rubric_version: Optional version of the rubric. If not provided, uses latest.
 
         Returns:
-            list: List of dictionaries containing search cluster information.
+            list: List of dictionaries containing centroid information.
 
         Raises:
             requests.exceptions.HTTPError: If the API request fails.
         """
-        url = f"{self._server_url}/{collection_id}/list_search_clusters"
-        response = self._session.post(url, json={"search_query": search_query})
+        url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/centroids"
+        params: dict[str, int] = {}
+        if rubric_version is not None:
+            params["rubric_version"] = rubric_version
+        response = self._session.get(url, params=params)
         response.raise_for_status()
         return response.json()
 
-    def get_cluster_matches(self, collection_id: str, centroid: str) -> list[dict[str, Any]]:
-        """Get the matches for a given cluster.
+    def get_cluster_matches(
+        self, collection_id: str, rubric_id: str, rubric_version: int
+    ) -> list[dict[str, Any]]:
+        """Get centroid assignments for a given rubric.
 
         Args:
             collection_id: ID of the Collection.
-            cluster_id: The ID of the cluster to get matches for.
+            rubric_id: The ID of the rubric to get assignments for.
+            rubric_version: The version of the rubric to get assignments for.
 
         Returns:
-            list: List of dictionaries containing the search results that match the cluster.
+            list: List of dictionaries containing centroid assignment information.
 
         Raises:
             requests.exceptions.HTTPError: If the API request fails.
         """
-        url = f"{self._server_url}/{collection_id}/get_cluster_matches"
-        response = self._session.post(url, json={"centroid": centroid})
+        url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/assignments"
+        response = self._session.get(url, params={"rubric_version": rubric_version})
         response.raise_for_status()
         return response.json()
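
The four client methods keep their old names but now hit the rubric endpoints with GET requests and `rubric_version` query parameters. A usage sketch, assuming a configured `Docent` client; the constructor arguments and the `"id"` key in the response payload are assumptions, not documented API:

```python
from docent.sdk.client import Docent

client = Docent(api_key="...")  # hypothetical constructor arguments

collection_id = "col_123"       # hypothetical collection ID
rubrics = client.list_searches(collection_id)
if rubrics:
    rubric_id = rubrics[0]["id"]  # assumed response shape
    results = client.get_search_results(collection_id, rubric_id, rubric_version=1)
    centroids = client.list_search_clusters(collection_id, rubric_id)  # latest version
    matches = client.get_cluster_matches(collection_id, rubric_id, rubric_version=1)
    print(len(results), len(centroids), len(matches))
```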