docent-python 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python might be problematic. Click here for more details.
- docent/__init__.py +3 -0
- docent/_log_util/__init__.py +3 -0
- docent/_log_util/logger.py +141 -0
- docent/data_models/__init__.py +25 -0
- docent/data_models/_tiktoken_util.py +91 -0
- docent/data_models/agent_run.py +231 -0
- docent/data_models/chat/__init__.py +25 -0
- docent/data_models/chat/content.py +56 -0
- docent/data_models/chat/message.py +125 -0
- docent/data_models/chat/tool.py +109 -0
- docent/data_models/citation.py +223 -0
- docent/data_models/filters.py +205 -0
- docent/data_models/metadata.py +219 -0
- docent/data_models/regex.py +56 -0
- docent/data_models/shared_types.py +10 -0
- docent/data_models/transcript.py +347 -0
- docent/py.typed +0 -0
- docent/sdk/__init__.py +0 -0
- docent/sdk/client.py +285 -0
- docent_python-0.1.0a1.dist-info/METADATA +16 -0
- docent_python-0.1.0a1.dist-info/RECORD +23 -0
- docent_python-0.1.0a1.dist-info/WHEEL +4 -0
- docent_python-0.1.0a1.dist-info/licenses/LICENSE.md +7 -0
docent/__init__.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, Literal, MutableMapping, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class ColorCode:
    """An ANSI escape-sequence pair used to colorize terminal text.

    Attributes:
        fore: Foreground color escape sequence (e.g. ``"\\033[31m"`` for red).
        style: Optional style escape sequence (e.g. bold), emitted before ``fore``.
    """

    fore: str
    style: str = ""


class Colors:
    """ANSI color/style escape sequences plus a helper to apply them to text."""

    # Foreground colors
    BLACK = ColorCode("\033[30m")
    RED = ColorCode("\033[31m")
    GREEN = ColorCode("\033[32m")
    YELLOW = ColorCode("\033[33m")
    BLUE = ColorCode("\033[34m")
    MAGENTA = ColorCode("\033[35m")
    CYAN = ColorCode("\033[36m")
    WHITE = ColorCode("\033[37m")
    BRIGHT_MAGENTA = ColorCode("\033[95m")
    BRIGHT_CYAN = ColorCode("\033[96m")

    # Styles
    BOLD = "\033[1m"
    RESET = "\033[0m"

    @staticmethod
    def apply(text: str, color: ColorCode) -> str:
        """Return ``text`` wrapped in ``color``'s style and foreground codes, followed by a reset."""
        return f"{color.style}{color.fore}{text}{Colors.RESET}"


class ColoredFormatter(logging.Formatter):
    """Formatter that colors the level name and namespace and can highlight the message.

    The default format references ``%(namespace)s``. Records without a truthy
    ``namespace`` attribute use the logger name instead. A record with a truthy
    ``highlight`` attribute has its fully formatted message wrapped in a bold color
    from ``HIGHLIGHT_COLORS``. If the value is not a string or is an unknown name,
    magenta is used.

    Formatting works on a copy of the record. The colorized fields therefore never
    leak into other handlers that receive the same record, and formatting one record
    twice yields the same output.
    """

    COLORS: Dict[int, ColorCode] = {
        logging.DEBUG: Colors.BLUE,
        logging.INFO: Colors.GREEN,
        logging.WARNING: Colors.YELLOW,
        logging.ERROR: Colors.RED,
        logging.CRITICAL: ColorCode("\033[31m", Colors.BOLD),
    }

    # Available highlight colors. "green" matches the option accepted by LoggerAdapter.highlight.
    HIGHLIGHT_COLORS: Dict[str, ColorCode] = {
        "magenta": ColorCode(Colors.BRIGHT_MAGENTA.fore, Colors.BOLD),
        "cyan": ColorCode(Colors.BRIGHT_CYAN.fore, Colors.BOLD),
        "yellow": ColorCode(Colors.YELLOW.fore, Colors.BOLD),
        "red": ColorCode(Colors.RED.fore, Colors.BOLD),
        "green": ColorCode(Colors.GREEN.fore, Colors.BOLD),
    }

    def __init__(self, fmt: Optional[str] = None) -> None:
        """Create the formatter.

        Args:
            fmt: Optional %-style format string. It defaults to
                ``"%(asctime)s [%(levelname)s] %(namespace)s: %(message)s"``.
                Times are always rendered as ``%H:%M:%S``.
        """
        super().__init__(
            fmt or "%(asctime)s [%(levelname)s] %(namespace)s: %(message)s", datefmt="%H:%M:%S"
        )

    def format(self, record: logging.LogRecord) -> str:
        # Colorize a shallow copy. All handlers share one LogRecord, so mutating it would
        # leak escape codes into other handlers and double-color on a repeated format().
        record = logging.makeLogRecord(record.__dict__)

        # Fall back to the logger name when no namespace was supplied via ``extra``
        namespace = getattr(record, "namespace", None) or record.name
        record.__dict__["namespace"] = Colors.apply(namespace, Colors.CYAN)

        # Color the level name. Custom levels without a configured color are left plain
        level_color = self.COLORS.get(record.levelno)
        if level_color is not None:
            record.levelname = Colors.apply(record.levelname, level_color)

        highlight = getattr(record, "highlight", None)
        if highlight:
            color_name = highlight if isinstance(highlight, str) else "magenta"
            highlight_color = self.HIGHLIGHT_COLORS.get(
                color_name, self.HIGHLIGHT_COLORS["magenta"]
            )
            # Merge the args into the message first, so the color wraps the final text
            record.msg = Colors.apply(record.getMessage(), highlight_color)
            record.args = ()

        return super().format(record)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class LoggerAdapter(logging.LoggerAdapter[logging.Logger]):
|
|
87
|
+
"""
|
|
88
|
+
Logger adapter that allows highlighting specific log messages.
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
def process(
|
|
92
|
+
self, msg: Any, kwargs: MutableMapping[str, Any]
|
|
93
|
+
) -> Tuple[Any, MutableMapping[str, Any]]:
|
|
94
|
+
# Pass highlight flag through to the record
|
|
95
|
+
return msg, kwargs
|
|
96
|
+
|
|
97
|
+
def highlight(
|
|
98
|
+
self,
|
|
99
|
+
msg: object,
|
|
100
|
+
*args: Any,
|
|
101
|
+
color: Literal["magenta", "cyan", "yellow", "red", "green"] = "magenta",
|
|
102
|
+
**kwargs: Any,
|
|
103
|
+
) -> None:
|
|
104
|
+
"""
|
|
105
|
+
Log a highlighted message.
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
msg: The message format string
|
|
109
|
+
color: The color to highlight with (magenta, cyan, yellow, red)
|
|
110
|
+
*args: The args for the message format string
|
|
111
|
+
**kwargs: Additional logging kwargs
|
|
112
|
+
"""
|
|
113
|
+
kwargs.setdefault("extra", {})
|
|
114
|
+
if isinstance(kwargs["extra"], dict):
|
|
115
|
+
kwargs["extra"]["highlight"] = color
|
|
116
|
+
return self.info(msg, *args, **kwargs)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def get_logger(namespace: str) -> LoggerAdapter:
    """
    Get a colored logger for the specified namespace.

    The first call for a namespace attaches a stdout handler that uses
    ``ColoredFormatter``. The level defaults to INFO, but it is set only when the
    logger has no explicit level yet. A level configured by the application is
    therefore not overwritten by later calls.

    NOTE(review): propagation is left untouched, so records may also be emitted by
    handlers on ancestor loggers (e.g. the root logger).

    Args:
        namespace: The namespace for the logger (passed to ``logging.getLogger``)

    Returns:
        A LoggerAdapter wrapping the logger, with highlighting support
    """
    logger = logging.getLogger(namespace)

    # Only add a handler if none exists, so repeated calls don't duplicate output
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(ColoredFormatter())
        logger.addHandler(handler)

    # Default to INFO, but respect a level that has already been set explicitly
    if logger.level == logging.NOTSET:
        logger.setLevel(logging.INFO)

    # Wrap with adapter to support highlighting
    return LoggerAdapter(logger, {})
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from docent.data_models.agent_run import AgentRun
|
|
2
|
+
from docent.data_models.citation import Citation
|
|
3
|
+
from docent.data_models.filters import (
|
|
4
|
+
AgentRunIdFilter,
|
|
5
|
+
BaseFrameFilter,
|
|
6
|
+
ComplexFilter,
|
|
7
|
+
SearchResultPredicateFilter,
|
|
8
|
+
)
|
|
9
|
+
from docent.data_models.metadata import BaseAgentRunMetadata, BaseMetadata, FrameDimension
|
|
10
|
+
from docent.data_models.regex import RegexSnippet
|
|
11
|
+
from docent.data_models.transcript import Transcript
|
|
12
|
+
|
|
13
|
+
# Public API of this package: each name below is imported from one of the submodules
# above (agent_run, citation, filters, metadata, regex, transcript) and re-exported here.
__all__ = [
|
|
14
|
+
"AgentRun",
|
|
15
|
+
"Citation",
|
|
16
|
+
"RegexSnippet",
|
|
17
|
+
"AgentRunIdFilter",
|
|
18
|
+
"FrameDimension",
|
|
19
|
+
"BaseFrameFilter",
|
|
20
|
+
"SearchResultPredicateFilter",
|
|
21
|
+
"ComplexFilter",
|
|
22
|
+
"BaseAgentRunMetadata",
|
|
23
|
+
"BaseMetadata",
|
|
24
|
+
"Transcript",
|
|
25
|
+
]
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import tiktoken
|
|
2
|
+
|
|
3
|
+
# Module-level token budget constant. The helpers in this module do not reference it.
# NOTE(review): presumably consumed by callers elsewhere; confirm before changing it.
MAX_TOKENS = 100_000
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_token_count(text: str, model: str = "gpt-4") -> int:
    """Get the number of tokens in a text under the given model's tokenization scheme.

    Args:
        text: The text to tokenize. Substrings that look like special tokens
            (e.g. ``"<|endoftext|>"``) are counted as ordinary text.
        model: Model name passed to ``tiktoken.encoding_for_model``. Defaults to GPT-4.

    Returns:
        The number of tokens that ``text`` encodes to.
    """
    encoding = tiktoken.encoding_for_model(model)
    # tiktoken raises ValueError by default when the text contains special-token strings.
    # Transcripts are arbitrary data, so encode such strings as plain text instead.
    return len(encoding.encode(text, disallowed_special=()))
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def truncate_to_token_limit(text: str, max_tokens: int, model: str = "gpt-4") -> str:
    """Truncate text to stay within the specified token limit.

    Args:
        text: The text to truncate. Special-token strings are treated as plain text.
        max_tokens: Maximum number of tokens to keep. Values <= 0 yield ``""``.
        model: Model name passed to ``tiktoken.encoding_for_model``. Defaults to GPT-4.

    Returns:
        ``text`` unchanged if it already fits. Otherwise, the decoding of its first
        ``max_tokens`` tokens.
    """
    if max_tokens <= 0:
        # A negative slice bound would keep almost the whole text instead of nothing
        return ""

    encoding = tiktoken.encoding_for_model(model)
    # Treat special-token strings as plain text instead of raising ValueError
    tokens = encoding.encode(text, disallowed_special=())

    if len(tokens) <= max_tokens:
        return text

    return encoding.decode(tokens[:max_tokens])
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MessageRange:
    """A range of messages in a transcript. start is inclusive, end is exclusive.

    Attributes:
        start: Index of the first message in the range.
        end: One past the index of the last message in the range.
        include_metadata: Whether the metadata is rendered together with this range.
        num_tokens: Token count of the range, including metadata when ``include_metadata``.
    """

    start: int
    end: int
    include_metadata: bool
    num_tokens: int

    def __init__(self, start: int, end: int, include_metadata: bool, num_tokens: int):
        self.start = start
        self.end = end
        self.include_metadata = include_metadata
        self.num_tokens = num_tokens

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(start={self.start}, end={self.end}, "
            f"include_metadata={self.include_metadata}, num_tokens={self.num_tokens})"
        )


def group_messages_into_ranges(
    token_counts: list[int], metadata_tokens: int, max_tokens: int, margin: int = 50
) -> list[MessageRange]:
    """Split a list of messages + metadata into ranges that stay within the specified token limit.

    Messages are packed greedily and in order into ranges that carry the metadata,
    as long as ``sum(range) + metadata_tokens <= max_tokens - margin``. A message that
    does not fit alongside the metadata even on its own is emitted as a lone range
    with ``include_metadata=False``. Such a range may still exceed the limit, and
    callers must split it further.

    Args:
        token_counts: Token count of each message, in order.
        metadata_tokens: Token count of the metadata attached to each packed range.
        max_tokens: Token budget per range.
        margin: Safety margin subtracted from ``max_tokens``.

    Returns:
        Ranges that cover every index of ``token_counts`` exactly once, in order.
    """
    budget = max_tokens - margin
    ranges: list[MessageRange] = []
    start_index = 0
    running_token_count = 0  # invariant: sum(token_counts[start_index:i])

    i = 0
    while i < len(token_counts):
        new_token_count = token_counts[i]
        if running_token_count + new_token_count + metadata_tokens <= budget:
            running_token_count += new_token_count
            i += 1
        elif start_index == i:
            # A single message + metadata is already too long: emit it alone, without metadata
            ranges.append(
                MessageRange(
                    start=i, end=i + 1, include_metadata=False, num_tokens=new_token_count
                )
            )
            i += 1
            # The next range starts after this message, so it is never included twice
            start_index = i
        else:
            # Close messages start_index..i-1 (with metadata); retry message i in a new range
            ranges.append(
                MessageRange(
                    start=start_index,
                    end=i,
                    include_metadata=True,
                    num_tokens=running_token_count + metadata_tokens,
                )
            )
            running_token_count = 0
            start_index = i

    # Flush the trailing range. Test the index rather than the token sum, so that
    # zero-token messages are not dropped.
    if start_index < len(token_counts):
        # Same comparison as the loop, so a range that exactly fills the budget keeps its metadata
        include_metadata = running_token_count + metadata_tokens <= budget
        num_tokens = (
            running_token_count + metadata_tokens if include_metadata else running_token_count
        )
        ranges.append(
            MessageRange(
                start=start_index,
                end=len(token_counts),
                include_metadata=include_metadata,
                num_tokens=num_tokens,
            )
        )

    return ranges
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from typing import Any, Literal, TypedDict, cast
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
from pydantic import (
|
|
7
|
+
BaseModel,
|
|
8
|
+
Field,
|
|
9
|
+
field_serializer,
|
|
10
|
+
field_validator,
|
|
11
|
+
model_validator,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from docent.data_models._tiktoken_util import get_token_count, group_messages_into_ranges
|
|
15
|
+
from docent.data_models.metadata import BaseAgentRunMetadata
|
|
16
|
+
from docent.data_models.transcript import Transcript, TranscriptWithoutMetadataValidator
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# A field that agent runs can be filtered on: its dotted path ("name") and the
# name of its scalar type ("type").
FilterableField = TypedDict(
    "FilterableField",
    {
        "name": str,
        "type": Literal["str", "bool", "int", "float"],
    },
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class AgentRun(BaseModel):
    """Represents a complete run of an agent with transcripts and metadata.

    An AgentRun encapsulates the execution of an agent, storing all communication
    transcripts and associated metadata. It must contain at least one transcript.

    Attributes:
        id: Unique identifier for the agent run, auto-generated (uuid4) by default.
        name: Optional human-readable name for the agent run.
        description: Optional description of the agent run.
        transcripts: Dict mapping transcript IDs to Transcript objects.
        metadata: Additional structured metadata about the agent run.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str | None = None
    description: str | None = None

    transcripts: dict[str, Transcript]
    metadata: BaseAgentRunMetadata

    @field_serializer("metadata")
    def serialize_metadata(self, metadata: BaseAgentRunMetadata, _info: Any) -> dict[str, Any]:
        """
        Custom serializer for the metadata field so the internal fields are explicitly preserved.
        """
        return metadata.model_dump(strip_internal_fields=False)

    @field_validator("metadata", mode="before")
    @classmethod
    def _validate_metadata_type(cls, v: Any) -> Any:
        """Reject metadata that is not already a BaseAgentRunMetadata instance.

        This runs in "before" mode, so plain dicts are rejected rather than coerced.
        ``None`` is passed through unchanged.
        """
        if v is not None and not isinstance(v, BaseAgentRunMetadata):
            raise ValueError(
                f"metadata must be an instance of BaseAgentRunMetadata, got {type(v).__name__}"
            )
        return v

    @model_validator(mode="after")
    def _validate_transcripts_not_empty(self):
        """Validates that the agent run contains at least one transcript.

        Raises:
            ValueError: If the transcripts dict is empty.

        Returns:
            AgentRun: The validated AgentRun instance.
        """
        if len(self.transcripts) == 0:
            raise ValueError("AgentRun must have at least one transcript")
        return self

    def to_text(self, token_limit: int = sys.maxsize) -> list[str]:
        """Render the agent run as one or more strings for analysis.

        Each string is intended to stay within ``token_limit`` tokens under the GPT-4
        tokenization scheme (see ``get_token_count``). When everything fits, a single
        string is returned: all transcripts followed by the YAML-dumped metadata.
        Otherwise, transcripts are grouped along transcript boundaries into chunks that
        each repeat the metadata. A transcript too large to share a chunk with the
        metadata is emitted on its own, without metadata. If it is still too long, it
        is further split with ``Transcript.to_str_with_token_limit``.

        Args:
            token_limit: Token budget per returned string. The default (``sys.maxsize``)
                skips token counting and always returns a single string.

        Returns:
            The list of rendered strings.
        """
        transcript_items = list(self.transcripts.items())
        transcript_strs: list[str] = [
            f"<transcript {t_key}>\n{t.to_str(agent_run_idx=None, transcript_idx=i)}\n</transcript {t_key}>"
            for i, (t_key, t) in enumerate(transcript_items)
        ]
        transcripts_str = "\n\n".join(transcript_strs)

        # Gather metadata
        metadata_obj = self.metadata.model_dump(strip_internal_fields=True)
        if self.name is not None:
            metadata_obj["name"] = self.name
        if self.description is not None:
            metadata_obj["description"] = self.description
        # Add the field descriptions if they exist
        metadata_obj = {
            (f"{k} ({d})" if (d := self.metadata.get_field_description(k)) is not None else k): v
            for k, v in metadata_obj.items()
        }

        # An infinite width stops yaml.dump from wrapping long values
        yaml_width = float("inf")
        transcripts_str = (
            f"Here is a complete agent run for analysis purposes only:\n{transcripts_str}\n\n"
        )
        metadata_str = f"Metadata about the complete agent run:\n<agent run metadata>\n{yaml.dump(metadata_obj, width=yaml_width)}\n</agent run metadata>"

        if token_limit == sys.maxsize:
            return [f"{transcripts_str}{metadata_str}"]

        # Compute message length; if it fits, return the full transcript and metadata
        transcript_str_tokens = get_token_count(transcripts_str)
        metadata_str_tokens = get_token_count(metadata_str)
        if transcript_str_tokens + metadata_str_tokens <= token_limit:
            return [f"{transcripts_str}{metadata_str}"]

        # Otherwise, split along transcript boundaries. Reserve 50 tokens for the chunk
        # header and separators.
        chunk_limit = token_limit - 50
        results: list[str] = []
        transcript_token_counts = [get_token_count(t) for t in transcript_strs]
        ranges = group_messages_into_ranges(
            transcript_token_counts, metadata_str_tokens, chunk_limit
        )
        for msg_range in ranges:
            if msg_range.include_metadata:
                cur_transcript_str = "\n\n".join(
                    transcript_strs[msg_range.start : msg_range.end]
                )
                # Separate the transcripts from the metadata, as in the single-string path
                results.append(
                    f"Here is a partial agent run for analysis purposes only:\n{cur_transcript_str}\n\n"
                    f"{metadata_str}"
                )
                continue

            assert (
                msg_range.end == msg_range.start + 1
            ), "Ranges without metadata should be a single message"
            if msg_range.num_tokens < chunk_limit:
                # Reuse the rendering whose tokens were counted (same transcript_idx labels)
                results.append(
                    "Here is a partial agent run for analysis purposes only:\n"
                    f"{transcript_strs[msg_range.start]}"
                )
            else:
                t_id, t = transcript_items[msg_range.start]
                for fragment in t.to_str_with_token_limit(chunk_limit):
                    results.append(
                        "Here is a partial agent run for analysis purposes only:\n"
                        f"<transcript {t_id}>\n{fragment}\n</transcript {t_id}>"
                    )
        return results

    @property
    def text(self) -> str:
        """Full single-string rendering of the agent run.

        Returns:
            str: ``to_text()`` with no token limit: all transcripts followed by the metadata.
        """
        return self.to_text()[0]

    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Extends the parent model_dump method to include the text property.

        Args:
            *args: Variable length argument list passed to parent method.
            **kwargs: Arbitrary keyword arguments passed to parent method.

        Returns:
            dict[str, Any]: Dictionary representation of the model including the text property.
        """
        return super().model_dump(*args, **kwargs) | {"text": self.text}

    def get_filterable_fields(self, max_depth: int = 1) -> list[FilterableField]:
        """Returns a list of all fields that can be used to filter the agent run.

        The method recursively explores the metadata ``model_dump()`` for str/int/float/bool
        values inside nested dictionaries.

        Args:
            max_depth: Deepest nesting level explored. Top-level metadata keys are
                depth 0.

        Returns:
            list[FilterableField]: A list of filterable fields. Each field is a
            dictionary containing its 'name' (dotted path) and 'type'. The ``text``
            field is always included last.
        """
        result: list[FilterableField] = []

        def _explore_dict(d: dict[str, Any], prefix: str, depth: int):
            if depth > max_depth:
                return

            for k, v in d.items():
                if isinstance(v, (str, int, float, bool)):
                    result.append(
                        {
                            "name": f"{prefix}.{k}",
                            "type": cast(Literal["str", "bool", "int", "float"], type(v).__name__),
                        }
                    )
                elif isinstance(v, dict):
                    _explore_dict(cast(dict[str, Any], v), f"{prefix}.{k}", depth + 1)

        # Look at the agent run metadata
        _explore_dict(self.metadata.model_dump(strip_internal_fields=True), "metadata", 0)
        # Look at the transcript metadata
        # TODO(mengk): restore this later when we have the ability to integrate with SQL.
        # for t_id, t in self.transcripts.items():
        #     _explore_dict(
        #         t.metadata.model_dump(strip_internal_fields=True), f"transcript.{t_id}.metadata", 0
        #     )

        # Append the text field
        result.append({"name": "text", "type": "str"})

        return result
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class AgentRunWithoutMetadataValidator(AgentRun):
|
|
220
|
+
"""
|
|
221
|
+
A version of AgentRun whose metadata field_validator accepts any value.
|
|
222
|
+
Needed for sending/receiving agent runs via JSON: JSON-decoded metadata is not a BaseAgentRunMetadata instance, so AgentRun._validate_metadata_type would reject it.
|
|
223
|
+
"""
|
|
224
|
+
|
|
225
|
+
# Nested transcripts also use the validator-free Transcript variant
transcripts: dict[str, TranscriptWithoutMetadataValidator] # type: ignore
|
|
226
|
+
|
|
227
|
+
# Keeps the same method name as AgentRun._validate_metadata_type so that it replaces the
# parent's check instead of running alongside it.
# NOTE(review): relies on pydantic overriding inherited validators by name; confirm on upgrades.
@field_validator("metadata", mode="before")
|
|
228
|
+
@classmethod
|
|
229
|
+
def _validate_metadata_type(cls, v: Any) -> Any:
|
|
230
|
+
# Accept any value: skip the parent's isinstance(BaseAgentRunMetadata) check
|
|
231
|
+
return v
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from docent.data_models.chat.content import Content, ContentReasoning, ContentText
|
|
2
|
+
from docent.data_models.chat.message import (
|
|
3
|
+
AssistantMessage,
|
|
4
|
+
ChatMessage,
|
|
5
|
+
SystemMessage,
|
|
6
|
+
ToolMessage,
|
|
7
|
+
UserMessage,
|
|
8
|
+
parse_chat_message,
|
|
9
|
+
)
|
|
10
|
+
from docent.data_models.chat.tool import ToolCall, ToolInfo, ToolParams
|
|
11
|
+
|
|
12
|
+
# Public API of this package: names re-exported from the content, message, and tool
# submodules imported above.
__all__ = [
|
|
13
|
+
"ChatMessage",
|
|
14
|
+
"AssistantMessage",
|
|
15
|
+
"SystemMessage",
|
|
16
|
+
"ToolMessage",
|
|
17
|
+
"UserMessage",
|
|
18
|
+
"Content",
|
|
19
|
+
"ContentReasoning",
|
|
20
|
+
"ContentText",
|
|
21
|
+
"ToolCall",
|
|
22
|
+
"ToolInfo",
|
|
23
|
+
"ToolParams",
|
|
24
|
+
"parse_chat_message",
|
|
25
|
+
]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import Annotated, Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Discriminator
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseContent(BaseModel):
|
|
7
|
+
"""Base class for all content types in chat messages.
|
|
8
|
+
|
|
9
|
+
Provides the foundation for different content types with a discriminator field.
|
|
10
|
+
|
|
11
|
+
Attributes:
|
|
12
|
+
type: The content type identifier, used for discriminating between content types.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
# Only "text" and "reasoning" have concrete subclasses in this module (ContentText,
# ContentReasoning). The other values have no subclass here.
type: Literal["text", "reasoning", "image", "audio", "video"]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ContentText(BaseContent):
|
|
19
|
+
"""Text content for chat messages.
|
|
20
|
+
|
|
21
|
+
Represents plain text content in a chat message.
|
|
22
|
+
|
|
23
|
+
Attributes:
|
|
24
|
+
type: Fixed as "text" to identify this content type.
|
|
25
|
+
text: The actual text content.
|
|
26
|
+
refusal: Optional flag indicating if this is a refusal message.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
# Narrows the base Literal to the single discriminator value. The type: ignore
# presumably silences the type checker's incompatible-override report.
type: Literal["text"] = "text" # type: ignore
|
|
30
|
+
text: str
|
|
31
|
+
# None (the default) means the refusal status was not specified
refusal: bool | None = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ContentReasoning(BaseContent):
|
|
35
|
+
"""Reasoning content for chat messages.
|
|
36
|
+
|
|
37
|
+
Represents reasoning or thought process content in a chat message.
|
|
38
|
+
|
|
39
|
+
Attributes:
|
|
40
|
+
type: Fixed as "reasoning" to identify this content type.
|
|
41
|
+
reasoning: The actual reasoning text.
|
|
42
|
+
signature: Optional signature associated with the reasoning.
|
|
43
|
+
redacted: Flag indicating if the reasoning has been redacted.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
# Narrows the base Literal to the single discriminator value. The type: ignore
# presumably silences the type checker's incompatible-override report.
type: Literal["reasoning"] = "reasoning" # type: ignore
|
|
47
|
+
reasoning: str
|
|
48
|
+
signature: str | None = None
|
|
49
|
+
# Defaults to False (not redacted)
redacted: bool = False
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Content type discriminated union
|
|
53
|
+
# Pydantic uses the value of the "type" field to decide whether input validates as
# ContentText or ContentReasoning.
Content = Annotated[ContentText | ContentReasoning, Discriminator("type")]
|
|
54
|
+
"""Discriminated union of possible content types using the 'type' field.
|
|
55
|
+
Can be either ContentText or ContentReasoning.
|
|
56
|
+
"""
|