docent-python 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of docent-python has been flagged as potentially problematic. See the registry's advisory for details.
- docent/__init__.py +3 -0
- docent/_log_util/__init__.py +3 -0
- docent/_log_util/logger.py +141 -0
- docent/data_models/__init__.py +25 -0
- docent/data_models/_tiktoken_util.py +91 -0
- docent/data_models/agent_run.py +231 -0
- docent/data_models/chat/__init__.py +25 -0
- docent/data_models/chat/content.py +56 -0
- docent/data_models/chat/message.py +125 -0
- docent/data_models/chat/tool.py +109 -0
- docent/data_models/citation.py +223 -0
- docent/data_models/filters.py +205 -0
- docent/data_models/metadata.py +219 -0
- docent/data_models/regex.py +56 -0
- docent/data_models/shared_types.py +10 -0
- docent/data_models/transcript.py +347 -0
- docent/py.typed +0 -0
- docent/sdk/__init__.py +0 -0
- docent/sdk/client.py +285 -0
- docent_python-0.1.0a1.dist-info/METADATA +16 -0
- docent_python-0.1.0a1.dist-info/RECORD +23 -0
- docent_python-0.1.0a1.dist-info/WHEEL +4 -0
- docent_python-0.1.0a1.dist-info/licenses/LICENSE.md +7 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
import traceback
|
|
2
|
+
from typing import Any, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import (
|
|
5
|
+
BaseModel,
|
|
6
|
+
ConfigDict,
|
|
7
|
+
Field,
|
|
8
|
+
PrivateAttr,
|
|
9
|
+
SerializerFunctionWrapHandler,
|
|
10
|
+
model_serializer,
|
|
11
|
+
model_validator,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from docent._log_util import get_logger
|
|
15
|
+
|
|
16
|
+
# Module-level logger, namespaced to this module.
logger = get_logger(__name__)

# Scalar ("singleton") value types, as opposed to containers.
# NOTE(review): not referenced anywhere else in this module — presumably
# imported by other modules; confirm before removing.
SINGLETONS = (int, float, str, bool)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BaseMetadata(BaseModel):
    """Provides common functionality for accessing and validating metadata fields.

    All metadata classes should inherit from this class.

    Serialization Behavior:
        - Field descriptions are highly recommended and stored in serialized versions of the object.
        - When a subclass of BaseMetadata is uploaded to a server, all extra fields and their descriptions are retained.
        - To recover the original structure with proper typing upon download, use:
          `CustomMetadataClass.model_validate(obj.model_dump())`.

    Attributes:
        model_config: Pydantic configuration that allows extra fields.
        allow_fields_without_descriptions: Boolean indicating whether to allow fields without descriptions.
    """

    # Extra (undeclared) fields are retained rather than rejected, so arbitrary
    # metadata can ride along with the declared schema.
    model_config = ConfigDict(extra="allow")
    allow_fields_without_descriptions: bool = False

    # Private attribute to store field descriptions. Computed lazily from the
    # schema, or restored from a serialized dump (see model_post_init), so
    # descriptions survive a serialize/deserialize round-trip.
    _field_descriptions: dict[str, str | None] | None = PrivateAttr(default=None)
    # Names that belong to BaseMetadata's own machinery rather than to
    # user-declared metadata; skipped during validation and description capture.
    _internal_basemetadata_fields: set[str] = PrivateAttr(
        default={
            "allow_fields_without_descriptions",
            "model_config",
            "_field_descriptions",
        }
    )

    @model_validator(mode="after")
    def _validate_field_types_and_descriptions(self):
        """Validates that all fields have descriptions and proper types.

        Returns:
            Self: The validated model instance.

        Raises:
            ValueError: If any field is missing a description or has an invalid type.
        """
        # Validate each declared, non-internal field in the model
        for field_name, field_info in self.__class__.model_fields.items():
            if field_name in self._internal_basemetadata_fields:
                continue

            # Check that field has a description
            if field_info.description is None:
                if not self.allow_fields_without_descriptions:
                    raise ValueError(
                        f"Field `{field_name}` needs a description in the definition of `{self.__class__.__name__}`, like `{field_name}: T = Field(description=..., default=...)`. "
                        "To allow un-described fields, set `allow_fields_without_descriptions = True` on the instance or in your metadata class definition."
                    )

        # Validate that the metadata is JSON serializable
        try:
            self.model_dump_json()
        except Exception as e:
            raise ValueError(
                f"Metadata is not JSON serializable: {e}. Traceback: {traceback.format_exc()}"
            )

        return self

    def model_post_init(self, __context: Any) -> None:
        """Initializes field descriptions from extra data after model initialization.

        When a serialized dump containing `_field_descriptions` is re-validated,
        the descriptions arrive as an extra field; move them back into the
        private attribute so they are not treated as a metadata value.

        Args:
            __context: The context provided by Pydantic's post-initialization hook.
        """
        fd = self.model_extra.pop("_field_descriptions", None) if self.model_extra else None
        if fd is not None:
            self._field_descriptions = fd

    @model_serializer(mode="wrap")
    def _serialize_model(self, handler: SerializerFunctionWrapHandler):
        # Call the default serializer
        data = handler(self)

        # Dump the field descriptions alongside the data so they survive a
        # round-trip (recovered by model_post_init on re-validation).
        if self._field_descriptions is None:
            self._field_descriptions = self._compute_field_descriptions()
        data["_field_descriptions"] = self._field_descriptions

        return data

    def model_dump(
        self, *args: Any, strip_internal_fields: bool = False, **kwargs: Any
    ) -> dict[str, Any]:
        """Dumps the model, optionally removing BaseMetadata-internal fields.

        Args:
            strip_internal_fields: If True, internal bookkeeping fields (e.g.
                `_field_descriptions`, added by the serializer) are removed.

        Returns:
            dict[str, Any]: The serialized model data.
        """
        data = super().model_dump(*args, **kwargs)

        # Remove internal fields if requested
        if strip_internal_fields:
            for field in self._internal_basemetadata_fields:
                if field in data:
                    data.pop(field)

        return data

    def get(self, key: str, default_value: Any = None) -> Any:
        """Gets a value from the metadata by key.

        Args:
            key: The key to look up in the metadata.
            default_value: Value to return if the key is not found. Defaults to None.

        Returns:
            Any: The value associated with the key, or the default value if not found.
        """
        # Check if the field exists in the model's declared fields, or arrived
        # as an extra field (extra="allow")
        if key in self.__class__.model_fields or (
            self.model_extra is not None and key in self.model_extra
        ):
            # Field exists, return its value (even if None)
            return getattr(self, key)

        logger.warning(f"Field '{key}' not found in {self.__class__.__name__}")
        return default_value

    def get_field_description(self, field_name: str) -> str | None:
        """Gets the description of a field defined in the model schema.

        Args:
            field_name: The name of the field.

        Returns:
            str or None: The description string if the field is defined in the model schema
                and has a description, otherwise None.
        """
        # Lazily populate the cache on first access
        if self._field_descriptions is None:
            self._field_descriptions = self._compute_field_descriptions()

        if field_name in self._field_descriptions:
            return self._field_descriptions[field_name]

        logger.warning(
            f"Field description for '{field_name}' not found in {self.__class__.__name__}"
        )
        return None

    def get_all_field_descriptions(self) -> dict[str, str | None]:
        """Gets descriptions for all fields defined in the model schema.

        Returns:
            dict: A dictionary mapping field names to their descriptions.
                Only includes fields that have descriptions defined in the schema.
        """
        if self._field_descriptions is None:
            self._field_descriptions = self._compute_field_descriptions()
        return self._field_descriptions

    def _compute_field_descriptions(self) -> dict[str, str | None]:
        """Computes descriptions for all fields in the model.

        Returns:
            dict: A dictionary mapping field names to their descriptions
                (None for fields declared without one).
        """
        field_descriptions: dict[str, str | None] = {}
        for field_name, field_info in self.__class__.model_fields.items():
            if field_name not in self._internal_basemetadata_fields:
                field_descriptions[field_name] = field_info.description
        return field_descriptions
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class BaseAgentRunMetadata(BaseMetadata):
    """Metadata schema for agent evaluation runs.

    Adds scoring on top of BaseMetadata: a dictionary of named metrics plus an
    optional pointer to the headline metric.

    Attributes:
        scores: Mapping of metric name to metric value.
        default_score_key: Name of the top-line metric, if any.
    """

    scores: dict[str, int | float | bool | None] = Field(
        description="A dict of score_key -> score_value. Use one key for each metric you're tracking."
    )
    default_score_key: str | None = Field(
        description="The default score key for the transcript; one top-line metric"
    )

    def get_default_score(self) -> int | float | bool | None:
        """Returns the value of the headline metric.

        Returns:
            int, float, bool, or None: The score stored under
                `default_score_key`, or None when no default key is configured
                (or the key is absent from `scores`).
        """
        key = self.default_score_key
        return None if key is None else self.scores.get(key)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class FrameDimension(BaseModel):
    """A dimension for organizing agent runs.

    Attributes:
        id: Identifier for the dimension.
        name: Human-readable name of the dimension.
        search_query: Optional search query associated with the dimension.
        metadata_key: Optional metadata key the dimension refers to.
        maintain_mece: Optional flag; presumably whether groupings should stay
            mutually exclusive / collectively exhaustive — TODO confirm.
        loading_clusters: Loading-state flag for clusters.
        loading_bins: Loading-state flag for bins.
        binIds: Optional list of bin descriptor dicts.
            NOTE(review): camelCase name breaks PEP 8; presumably required for
            wire compatibility with a client — confirm before renaming.
    """

    id: str
    name: str
    search_query: str | None = None
    metadata_key: str | None = None
    maintain_mece: bool | None = None
    loading_clusters: bool = False
    loading_bins: bool = False
    binIds: list[dict[str, Any]] | None = None
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
from docent._log_util import get_logger
|
|
6
|
+
|
|
7
|
+
# Module-level logger, namespaced to this module.
logger = get_logger(__name__)


class RegexSnippet(BaseModel):
    """A window of text surrounding a single regex match.

    Attributes:
        snippet: The matched text plus surrounding context.
        match_start: Start index of the match *within* `snippet`.
        match_end: End index (exclusive) of the match *within* `snippet`.
    """

    snippet: str
    match_start: int
    match_end: int
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_regex_snippets(text: str, pattern: str, window_size: int = 50) -> list[RegexSnippet]:
    """Extracts snippets from text that match a regex pattern, with surrounding context.

    Matching is case-insensitive and `.` matches newlines (IGNORECASE | DOTALL).
    An invalid pattern is logged and yields an empty list rather than raising.

    Args:
        text: The text to search in.
        pattern: The regex pattern to match.
        window_size: The number of characters to include before and after the match.

    Returns:
        A list of RegexSnippet objects containing the snippets and match positions.
        Empty if there are no matches or the pattern is invalid.
    """
    # Keep the try minimal: re.error can only come from compiling the
    # (caller-supplied, possibly invalid) pattern.
    try:
        compiled = re.compile(pattern, re.IGNORECASE | re.DOTALL)
    except re.error as e:
        logger.error(f"Got regex error: {e}")
        return []

    matches = list(compiled.finditer(text))
    if not matches:
        # Previously this condition was tested twice (once to warn, once to
        # return); merged into a single check.
        logger.warning(f"No regex matches found for {pattern}: this shouldn't happen!")
        return []

    snippets: list[RegexSnippet] = []
    for match in matches:
        start, end = match.span()

        # Calculate window around the match, clamped to the text bounds
        snippet_start = max(0, start - window_size)
        snippet_end = min(len(text), end + window_size)

        # Create the snippet with the match indices adjusted for the window
        snippets.append(
            RegexSnippet(
                snippet=text[snippet_start:snippet_end],
                match_start=start - snippet_start,
                match_end=end - snippet_start,
            )
        )

    return snippets
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from typing import Any
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
from pydantic import BaseModel, Field, PrivateAttr, field_serializer, field_validator
|
|
7
|
+
|
|
8
|
+
from docent.data_models._tiktoken_util import (
|
|
9
|
+
get_token_count,
|
|
10
|
+
group_messages_into_ranges,
|
|
11
|
+
truncate_to_token_limit,
|
|
12
|
+
)
|
|
13
|
+
from docent.data_models.chat import AssistantMessage, ChatMessage, ContentReasoning
|
|
14
|
+
from docent.data_models.metadata import BaseMetadata
|
|
15
|
+
|
|
16
|
+
# Template for formatting individual transcript blocks.
# `index_label` is e.g. "T0B3" or "R1T0B3" (see format_chat_message).
TRANSCRIPT_BLOCK_TEMPLATE = """
<{index_label} | role: {role}>
{content}
</{index_label}>
""".strip()

# Instructions for citing single transcript blocks (prompt text; do not edit
# casually — downstream citation parsing presumably depends on this format).
SINGLE_RUN_CITE_INSTRUCTION = "Each transcript and each block has a unique index. Cite the relevant indices in brackets when relevant, like [T<idx>B<idx>]. Use multiple tags to cite multiple blocks, like [T<idx1>B<idx1>][T<idx2>B<idx2>]. Use an inner dash to cite a range of blocks, like [T<idx1>B<idx1>-T<idx2>B<idx2>]. Remember to cite specific blocks and NOT action units."

# Instructions for citing multiple transcript blocks (adds the R<idx> run prefix).
MULTI_RUN_CITE_INSTRUCTION = "Each run, each transcript, and each block has a unique index. Cite the relevant indices in brackets when relevant, like [R<idx>T<idx>B<idx>]. Use multiple tags to cite multiple blocks, like [R<idx1>T<idx1>B<idx1>][R<idx2>T<idx2>B<idx2>]. Use an inner dash to cite a range of blocks, like [R<idx1>T<idx1>B<idx1>-R<idx2>T<idx2>B<idx2>]. Remember to cite specific blocks and NOT action units."
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def format_chat_message(
    message: ChatMessage,
    block_idx: int,
    transcript_idx: int = 0,
    agent_run_idx: int | None = None,
) -> str:
    """Render one chat message as an indexed transcript block.

    The block label is "T<t>B<b>", prefixed with "R<r>" when a run index is
    given. Assistant reasoning (if present) is prepended in <reasoning> tags,
    and tool calls are appended in <tool call> tags.

    Args:
        message: The chat message to render.
        block_idx: Index of this message within its transcript.
        transcript_idx: Index of the transcript. Defaults to 0.
        agent_run_idx: Optional index of the agent run.

    Returns:
        str: The message formatted with TRANSCRIPT_BLOCK_TEMPLATE.
    """
    index_label = (
        f"R{agent_run_idx}T{transcript_idx}B{block_idx}"
        if agent_run_idx is not None
        else f"T{transcript_idx}B{block_idx}"
    )

    is_assistant = isinstance(message, AssistantMessage)

    # Reasoning prefix; when several reasoning parts exist, the last one wins
    # (assignment, not concatenation — preserved from the original behavior).
    prefix = ""
    if is_assistant and message.content:
        for part in message.content:
            if isinstance(part, ContentReasoning):
                prefix = f"<reasoning>\n{part.reasoning}\n</reasoning>\n"

    # Main content text
    body = prefix + message.text

    # Append tool calls, preferring the rendered view when one exists
    if is_assistant and message.tool_calls:
        for call in message.tool_calls:
            if call.view:
                body += f"\n<tool call>\n{call.view.content}\n</tool call>"
            else:
                args = ", ".join(f"{k}={v}" for k, v in call.arguments.items())
                body += f"\n<tool call>\n{call.function}({args})\n</tool call>"

    return TRANSCRIPT_BLOCK_TEMPLATE.format(
        index_label=index_label, role=message.role, content=body
    )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class Transcript(BaseModel):
    """Represents a transcript of messages in a conversation with an AI agent.

    A transcript contains a sequence of messages exchanged between different roles
    (system, user, assistant, tool) and provides methods to organize these messages
    into logical units of action.

    Attributes:
        id: Unique identifier for the transcript, auto-generated by default.
        name: Optional human-readable name for the transcript.
        description: Optional description of the transcript.
        messages: List of chat messages in the transcript.
        metadata: Additional structured metadata about the transcript.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str | None = None
    description: str | None = None

    messages: list[ChatMessage]
    metadata: BaseMetadata = Field(default_factory=BaseMetadata)

    # Cached grouping of message indices; recomputed whenever messages change
    # (see __init__ and set_messages).
    _units_of_action: list[list[int]] | None = PrivateAttr(default=None)

    @field_serializer("metadata")
    def serialize_metadata(self, metadata: BaseMetadata, _info: Any) -> dict[str, Any]:
        """
        Custom serializer for the metadata field so the internal fields are explicitly preserved.
        """
        return metadata.model_dump(strip_internal_fields=False)

    @field_validator("metadata", mode="before")
    @classmethod
    def _validate_metadata_type(cls, v: Any) -> Any:
        """Rejects metadata values that are not BaseMetadata instances (None is allowed through)."""
        if v is not None and not isinstance(v, BaseMetadata):
            raise ValueError(
                f"metadata must be an instance of BaseMetadata, got {type(v).__name__}"
            )
        return v

    @property
    def units_of_action(self) -> list[list[int]]:
        """Get the units of action in the transcript.

        A unit of action represents a logical group of messages, such as a system message
        on its own or a user message followed by assistant responses and tool outputs.

        Returns:
            list[list[int]]: List of units of action, where each unit is a list of message indices.
        """
        if self._units_of_action is None:
            self._units_of_action = self._compute_units_of_action()
        return self._units_of_action

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Compute eagerly so later reads never observe None
        self._units_of_action = self._compute_units_of_action()

    def _compute_units_of_action(self) -> list[list[int]]:
        """Compute the units of action in the transcript.

        A unit of action is defined as:
        - A system prompt by itself
        - A group consisting of a user message, assistant response, and any associated tool outputs

        Returns:
            list[list[int]]: A list of units of action, where each unit is a list of message indices.
        """
        if not self.messages:
            return []

        units: list[list[int]] = []
        current_unit: list[int] = []

        def _start_new_unit():
            # Flush the accumulated unit (if any) and begin a fresh one
            nonlocal current_unit
            if current_unit:
                units.append(current_unit.copy())
                current_unit = []

        for i, message in enumerate(self.messages):
            role = message.role
            prev_message = self.messages[i - 1] if i > 0 else None

            # System messages are their own unit
            if role == "system":
                assert not current_unit, "System message should be the first message"
                units.append([i])

            # User message always starts a new unit UNLESS the previous message was a user message
            elif role == "user":
                if current_unit and prev_message and prev_message.role != "user":
                    _start_new_unit()
                current_unit.append(i)

            # Start a new unit if the previous message was not a user or assistant message
            elif role == "assistant":
                if (
                    current_unit
                    and prev_message
                    and prev_message.role != "user"
                    and prev_message.role != "assistant"
                ):
                    _start_new_unit()
                current_unit.append(i)

            # Tool messages are part of the current unit
            elif role == "tool":
                current_unit.append(i)

            else:
                raise ValueError(f"Unknown message role: {role}")

        # Add the last unit if it exists
        _start_new_unit()

        return units

    def get_first_block_in_action_unit(self, action_unit_idx: int) -> int | None:
        """Get the index of the first message in a given action unit.

        Args:
            action_unit_idx: The index of the action unit.

        Returns:
            int | None: The index of the first message in the action unit,
                or None if the action unit doesn't exist or is empty.
        """
        if not self._units_of_action:
            self._units_of_action = self._compute_units_of_action()

        if 0 <= action_unit_idx < len(self._units_of_action):
            unit = self._units_of_action[action_unit_idx]
            return unit[0] if unit else None
        return None

    def get_action_unit_for_block(self, block_idx: int) -> int | None:
        """Find the action unit that contains the specified message block.

        Args:
            block_idx: The index of the message block to find.

        Returns:
            int | None: The index of the action unit containing the block,
                or None if no action unit contains the block.
        """
        if not self._units_of_action:
            self._units_of_action = self._compute_units_of_action()

        for unit_idx, unit in enumerate(self._units_of_action):
            if block_idx in unit:
                return unit_idx
        return None

    def set_messages(self, messages: list[ChatMessage]):
        """Set the messages in the transcript and recompute units of action.

        Args:
            messages: The new list of chat messages to set.
        """
        self.messages = messages
        self._units_of_action = self._compute_units_of_action()

    def to_str(
        self,
        transcript_idx: int = 0,
        agent_run_idx: int | None = None,
        highlight_action_unit: int | None = None,
    ) -> str:
        """Render the whole transcript as a single string (no token limit).

        Args:
            transcript_idx: Index used in block labels. Defaults to 0.
            agent_run_idx: Optional run index used in block labels.
            highlight_action_unit: Optional action unit index to highlight.

        Returns:
            str: The fully rendered transcript, including metadata.
        """
        return self.to_str_with_token_limit(
            token_limit=sys.maxsize,
            agent_run_idx=agent_run_idx,
            transcript_idx=transcript_idx,
            highlight_action_unit=highlight_action_unit,
        )[0]

    def to_str_with_token_limit(
        self,
        token_limit: int,
        transcript_idx: int = 0,
        agent_run_idx: int | None = None,
        highlight_action_unit: int | None = None,
    ) -> list[str]:
        """Represents the transcript as a list of strings, each of which is at most token_limit tokens
        under the GPT-4 tokenization scheme.

        We'll try to split up long transcripts along message boundaries and include metadata.
        For very long messages, we'll have to truncate them and remove metadata.

        Args:
            token_limit: Maximum tokens per returned string.
            transcript_idx: Index used in block labels. Defaults to 0.
            agent_run_idx: Optional run index used in block labels.
            highlight_action_unit: Optional action unit index to wrap in
                <HIGHLIGHTED> tags.

        Returns:
            list[str]: A list of strings, each of which is at most token_limit tokens
                under the GPT-4 tokenization scheme.

        Raises:
            ValueError: If highlight_action_unit is out of range.
        """
        if highlight_action_unit is not None and not (
            0 <= highlight_action_unit < len(self._units_of_action or [])
        ):
            raise ValueError(f"Invalid action unit index: {highlight_action_unit}")

        # Format blocks by units of action
        au_blocks: list[str] = []
        for unit_idx, unit in enumerate(self._units_of_action or []):
            unit_blocks: list[str] = []
            for msg_idx in unit:
                unit_blocks.append(
                    format_chat_message(
                        self.messages[msg_idx],
                        msg_idx,
                        transcript_idx,
                        agent_run_idx,
                    )
                )

            unit_content = "\n".join(unit_blocks)

            # Add highlighting if requested.
            # BUGFIX: compare against None rather than truthiness — the old
            # `if highlight_action_unit and ...` meant action unit 0 could
            # never be highlighted, despite the range check above admitting 0.
            if highlight_action_unit is not None and unit_idx == highlight_action_unit:
                blocks_str_template = "<HIGHLIGHTED>\n{}\n</HIGHLIGHTED>"
            else:
                blocks_str_template = "{}"
            au_blocks.append(
                blocks_str_template.format(
                    f"<action unit {unit_idx}>\n{unit_content}\n</action unit {unit_idx}>"
                )
            )
        blocks_str = "\n".join(au_blocks)

        # Gather metadata
        metadata_obj = self.metadata.model_dump(strip_internal_fields=True)
        # Add the field descriptions if they exist
        metadata_obj = {
            (f"{k} ({d})" if (d := self.metadata.get_field_description(k)) is not None else k): v
            for k, v in metadata_obj.items()
        }

        # Disable yaml line-wrapping so token counts stay predictable
        yaml_width = float("inf")
        block_str = f"<blocks>\n{blocks_str}\n</blocks>\n"
        metadata_str = f"<metadata>\n{yaml.dump(metadata_obj, width=yaml_width)}\n</metadata>"

        if token_limit == sys.maxsize:
            return [f"{block_str}" f"{metadata_str}"]

        metadata_token_count = get_token_count(metadata_str)
        block_token_count = get_token_count(block_str)

        if metadata_token_count + block_token_count <= token_limit:
            return [f"{block_str}" f"{metadata_str}"]
        else:
            results: list[str] = []
            block_token_counts = [get_token_count(block) for block in au_blocks]
            ranges = group_messages_into_ranges(
                block_token_counts, metadata_token_count, token_limit
            )
            for msg_range in ranges:
                if msg_range.include_metadata:
                    cur_au_blocks = "\n".join(au_blocks[msg_range.start : msg_range.end])
                    results.append(f"<blocks>\n{cur_au_blocks}\n</blocks>\n" f"{metadata_str}")
                else:
                    assert (
                        msg_range.end == msg_range.start + 1
                    ), "Ranges without metadata should be a single message"
                    result = str(au_blocks[msg_range.start])
                    # Leave ~10 tokens of headroom for the <blocks> wrapper tags
                    if msg_range.num_tokens > token_limit - 10:
                        result = truncate_to_token_limit(result, token_limit - 10)
                    results.append(f"<blocks>\n{result}\n</blocks>\n")

            return results
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
class TranscriptWithoutMetadataValidator(Transcript):
    """
    A version of Transcript that doesn't have the model_validator on metadata.
    Needed for sending/receiving transcripts via JSON, since they incorrectly trip the existing model_validator.
    """

    @field_validator("metadata", mode="before")
    @classmethod
    def _validate_metadata_type(cls, v: Any) -> Any:
        # Bypass the model_validator: metadata arriving from JSON is a plain
        # dict, which the parent's isinstance(BaseMetadata) check would reject.
        return v
|
docent/py.typed
ADDED
|
File without changes
|
docent/sdk/__init__.py
ADDED
|
File without changes
|