docent-python 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic. Click here for more details.

docent/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ __all__ = ["Docent"]
2
+
3
+ from docent.sdk.client import Docent
@@ -0,0 +1,3 @@
1
+ __all__ = ["get_logger"]
2
+
3
+ from docent._log_util.logger import get_logger
@@ -0,0 +1,141 @@
1
+ import logging
2
+ import sys
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, Literal, MutableMapping, Optional, Tuple
5
+
6
+
7
@dataclass
class ColorCode:
    """Pair of ANSI escape sequences describing how a piece of text is rendered.

    Attributes:
        fore: Escape sequence that selects the foreground color.
        style: Escape sequence for an additional style (for example bold).
            Defaults to the empty string, i.e. no extra style.
    """

    fore: str
    style: str = ""
11
+
12
+
13
class Colors:
    """Namespace of ANSI escape codes plus a helper to wrap text in them."""

    # Foreground colors
    BLACK = ColorCode("\033[30m")
    RED = ColorCode("\033[31m")
    GREEN = ColorCode("\033[32m")
    YELLOW = ColorCode("\033[33m")
    BLUE = ColorCode("\033[34m")
    MAGENTA = ColorCode("\033[35m")
    CYAN = ColorCode("\033[36m")
    WHITE = ColorCode("\033[37m")
    BRIGHT_MAGENTA = ColorCode("\033[95m")
    BRIGHT_CYAN = ColorCode("\033[96m")

    # Styles
    BOLD = "\033[1m"
    RESET = "\033[0m"

    @staticmethod
    def apply(text: str, color: ColorCode) -> str:
        """Return ``text`` preceded by the color's style and foreground codes
        and followed by the reset code."""
        return "{}{}{}{}".format(color.style, color.fore, text, Colors.RESET)
33
+
34
+
35
class ColoredFormatter(logging.Formatter):
    """Formatter that colors the level name and namespace of each record and,
    when the record carries a truthy ``highlight`` attribute, the whole message.

    The default format is ``"%(asctime)s [%(levelname)s] %(namespace)s: %(message)s"``
    with timestamps rendered as ``%H:%M:%S``.
    """

    # Color used for the level name, keyed by the standard logging levels.
    COLORS: Dict[int, ColorCode] = {
        logging.DEBUG: Colors.BLUE,
        logging.INFO: Colors.GREEN,
        logging.WARNING: Colors.YELLOW,
        logging.ERROR: Colors.RED,
        logging.CRITICAL: ColorCode("\033[31m", Colors.BOLD),
    }

    # Available highlight colors, keyed by the name stored in ``record.highlight``.
    HIGHLIGHT_COLORS: Dict[str, ColorCode] = {
        "magenta": ColorCode(Colors.BRIGHT_MAGENTA.fore, Colors.BOLD),
        "cyan": ColorCode(Colors.BRIGHT_CYAN.fore, Colors.BOLD),
        "yellow": ColorCode(Colors.YELLOW.fore, Colors.BOLD),
        "red": ColorCode(Colors.RED.fore, Colors.BOLD),
        # "green" is an accepted color name of LoggerAdapter.highlight; without
        # an entry here it would silently be rendered as magenta.
        "green": ColorCode(Colors.GREEN.fore, Colors.BOLD),
    }

    def __init__(self, fmt: Optional[str] = None) -> None:
        """Create the formatter.

        Args:
            fmt: Optional %-style format string; falls back to the default
                format described on the class when empty or ``None``.
        """
        super().__init__(
            fmt or "%(asctime)s [%(levelname)s] %(namespace)s: %(message)s", datefmt="%H:%M:%S"
        )

    def format(self, record: logging.LogRecord) -> str:
        """Return the colored text for ``record``.

        The incoming record is left untouched: the same LogRecord instance is
        handed to every handler, so writing escape codes into it would leak
        them into other handlers' output and wrap them a second time if the
        record were formatted again.

        Args:
            record: The log record to render.

        Returns:
            The formatted log line including ANSI escape sequences.
        """
        # Style a private copy instead of the shared record.
        styled = logging.makeLogRecord(record.__dict__)

        # Fall back to the logger name when no namespace was supplied via `extra`.
        namespace = getattr(styled, "namespace", None) or styled.name

        # Color the level name; custom levels without a mapping stay uncolored
        # instead of raising KeyError in the middle of a logging call.
        level_color = self.COLORS.get(styled.levelno)
        if level_color is not None:
            styled.levelname = Colors.apply(styled.levelname, level_color)

        # Color the namespace
        styled.__dict__["namespace"] = Colors.apply(namespace, Colors.CYAN)

        # Check if highlight flag is set
        highlight = getattr(styled, "highlight", None)
        if highlight:
            # Non-string truthy flags and unknown names both default to magenta.
            color_name = highlight if isinstance(highlight, str) else "magenta"
            highlight_color = self.HIGHLIGHT_COLORS.get(
                color_name, self.HIGHLIGHT_COLORS["magenta"]
            )

            # The message is interpolated first, so the args must be cleared to
            # keep the base formatter from applying them a second time.
            styled.msg = Colors.apply(styled.getMessage(), highlight_color)
            styled.args = ()

        return super().format(styled)
84
+
85
+
86
+ class LoggerAdapter(logging.LoggerAdapter[logging.Logger]):
87
+ """
88
+ Logger adapter that allows highlighting specific log messages.
89
+ """
90
+
91
+ def process(
92
+ self, msg: Any, kwargs: MutableMapping[str, Any]
93
+ ) -> Tuple[Any, MutableMapping[str, Any]]:
94
+ # Pass highlight flag through to the record
95
+ return msg, kwargs
96
+
97
+ def highlight(
98
+ self,
99
+ msg: object,
100
+ *args: Any,
101
+ color: Literal["magenta", "cyan", "yellow", "red", "green"] = "magenta",
102
+ **kwargs: Any,
103
+ ) -> None:
104
+ """
105
+ Log a highlighted message.
106
+
107
+ Args:
108
+ msg: The message format string
109
+ color: The color to highlight with (magenta, cyan, yellow, red)
110
+ *args: The args for the message format string
111
+ **kwargs: Additional logging kwargs
112
+ """
113
+ kwargs.setdefault("extra", {})
114
+ if isinstance(kwargs["extra"], dict):
115
+ kwargs["extra"]["highlight"] = color
116
+ return self.info(msg, *args, **kwargs)
117
+
118
+
119
def get_logger(namespace: str) -> LoggerAdapter:
    """
    Return a highlighting-capable logger for ``namespace``.

    Args:
        namespace: Name passed to ``logging.getLogger``

    Returns:
        A ``LoggerAdapter`` wrapping the named logger
    """
    base_logger = logging.getLogger(namespace)

    # A logger that already has handlers is treated as configured and reused,
    # so repeated calls do not stack duplicate stdout handlers.
    needs_setup = not base_logger.handlers
    if needs_setup:
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(ColoredFormatter())
        base_logger.addHandler(stdout_handler)

        # Default level is INFO.
        # NOTE(review): indentation was lost in the diff rendering; the level is
        # assumed to be set only during first-time setup -- confirm against the
        # packaged source.
        base_logger.setLevel(logging.INFO)

    # Wrap with the adapter to expose `.highlight(...)`.
    return LoggerAdapter(base_logger, {})
@@ -0,0 +1,25 @@
1
+ from docent.data_models.agent_run import AgentRun
2
+ from docent.data_models.citation import Citation
3
+ from docent.data_models.filters import (
4
+ AgentRunIdFilter,
5
+ BaseFrameFilter,
6
+ ComplexFilter,
7
+ SearchResultPredicateFilter,
8
+ )
9
+ from docent.data_models.metadata import BaseAgentRunMetadata, BaseMetadata, FrameDimension
10
+ from docent.data_models.regex import RegexSnippet
11
+ from docent.data_models.transcript import Transcript
12
+
13
+ __all__ = [
14
+ "AgentRun",
15
+ "Citation",
16
+ "RegexSnippet",
17
+ "AgentRunIdFilter",
18
+ "FrameDimension",
19
+ "BaseFrameFilter",
20
+ "SearchResultPredicateFilter",
21
+ "ComplexFilter",
22
+ "BaseAgentRunMetadata",
23
+ "BaseMetadata",
24
+ "Transcript",
25
+ ]
@@ -0,0 +1,91 @@
1
+ import tiktoken
2
+
3
# Module-level token budget. Not referenced by the helpers in this module;
# presumably consumed by importers -- TODO confirm.
MAX_TOKENS = 100_000
4
+
5
+
6
def get_token_count(text: str, model: str = "gpt-4") -> int:
    """Get the number of tokens in a text under the tokenization scheme of ``model``.

    Args:
        text: The text to tokenize.
        model: Model name used to look up the tiktoken encoding (GPT-4 by default).

    Returns:
        The number of tokens ``text`` encodes to.
    """
    encoding = tiktoken.encoding_for_model(model)
    # With tiktoken's defaults, text that happens to contain a special-token
    # string (e.g. "<|endoftext|>") raises ValueError. Arbitrary text can
    # legitimately contain such strings, so encode them as ordinary text.
    return len(encoding.encode(text, disallowed_special=()))
10
+
11
+
12
def truncate_to_token_limit(text: str, max_tokens: int, model: str = "gpt-4") -> str:
    """Truncate text to stay within the specified token limit.

    Args:
        text: The text to truncate.
        max_tokens: Maximum number of tokens to keep; values below zero are
            treated as zero.
        model: Model name used to look up the tiktoken encoding.

    Returns:
        ``text`` unchanged when it already fits, otherwise the decoding of its
        first ``max_tokens`` tokens.
    """
    encoding = tiktoken.encoding_for_model(model)
    # Encode special-token look-alikes (e.g. "<|endoftext|>") as ordinary text
    # instead of letting tiktoken raise ValueError on them.
    tokens = encoding.encode(text, disallowed_special=())

    if len(tokens) <= max_tokens:
        return text

    # Clamp so a negative limit cannot turn the slice into "drop the last k tokens".
    return encoding.decode(tokens[: max(max_tokens, 0)])
21
+
22
+
23
class MessageRange:
    """A range of messages in a transcript. start is inclusive, end is exclusive."""

    start: int
    end: int
    include_metadata: bool
    num_tokens: int

    def __init__(self, start: int, end: int, include_metadata: bool, num_tokens: int):
        """Store the range bounds, whether metadata accompanies the range, and
        the token count recorded for it."""
        self.start = start
        self.end = end
        self.include_metadata = include_metadata
        self.num_tokens = num_tokens


def group_messages_into_ranges(
    token_counts: list[int], metadata_tokens: int, max_tokens: int, margin: int = 50
) -> list[MessageRange]:
    """Split a list of messages + metadata into ranges that stay within the specified token limit.

    Always tries to create ranges with metadata included, unless a single message + metadata is too long,
    in which case you get a lone message with no metadata.

    Args:
        token_counts: Token count of each message, in order.
        metadata_tokens: Token count of the metadata attached to each range.
        max_tokens: Token limit for a range.
        margin: Safety margin subtracted from ``max_tokens``.

    Returns:
        Consecutive, non-overlapping ranges covering every message. Ranges with
        ``include_metadata=True`` report ``num_tokens`` including the metadata;
        a range with ``include_metadata=False`` always holds exactly one message
        and reports that message's token count only.
    """
    budget = max_tokens - margin
    ranges: list[MessageRange] = []
    start_index = 0  # first message of the range currently being accumulated
    running_token_count = 0  # tokens of messages [start_index, i)

    for i, new_token_count in enumerate(token_counts):
        overflows = running_token_count + new_token_count + metadata_tokens > budget
        if overflows and start_index < i:
            # Close the accumulated range (messages start_index..i-1) with metadata,
            # then reconsider message i against an empty range.
            ranges.append(
                MessageRange(
                    start=start_index,
                    end=i,
                    include_metadata=True,
                    num_tokens=running_token_count + metadata_tokens,
                )
            )
            running_token_count = 0
            start_index = i
            overflows = new_token_count + metadata_tokens > budget

        if overflows:
            # A single message + metadata is already too long: emit it alone, without metadata.
            ranges.append(
                MessageRange(start=i, end=i + 1, include_metadata=False, num_tokens=new_token_count)
            )
            start_index = i + 1
        else:
            running_token_count += new_token_count

    # Flush whatever is still being accumulated. Messages are only accumulated
    # while they fit together with the metadata, so this range always includes
    # it -- also when it fills the budget exactly. Checking the index rather than
    # the token sum keeps trailing zero-token messages from being dropped.
    if start_index < len(token_counts):
        ranges.append(
            MessageRange(
                start=start_index,
                end=len(token_counts),
                include_metadata=True,
                num_tokens=running_token_count + metadata_tokens,
            )
        )

    return ranges
@@ -0,0 +1,231 @@
1
+ import sys
2
+ from typing import Any, Literal, TypedDict, cast
3
+ from uuid import uuid4
4
+
5
+ import yaml
6
+ from pydantic import (
7
+ BaseModel,
8
+ Field,
9
+ field_serializer,
10
+ field_validator,
11
+ model_validator,
12
+ )
13
+
14
+ from docent.data_models._tiktoken_util import get_token_count, group_messages_into_ranges
15
+ from docent.data_models.metadata import BaseAgentRunMetadata
16
+ from docent.data_models.transcript import Transcript, TranscriptWithoutMetadataValidator
17
+
18
+
19
class FilterableField(TypedDict):
    """Description of one field an agent run can be filtered on."""

    # Dotted path of the field, e.g. "metadata.<key>" (or the literal "text").
    name: str
    # Python type name of the field's value.
    type: Literal["str", "bool", "int", "float"]
22
+
23
+
24
class AgentRun(BaseModel):
    """Represents a complete run of an agent with transcripts and metadata.

    An AgentRun encapsulates the execution of an agent, storing all communication
    transcripts and associated metadata. It must contain at least one transcript.

    Attributes:
        id: Unique identifier for the agent run, auto-generated by default.
        name: Optional human-readable name for the agent run.
        description: Optional description of the agent run.
        transcripts: Dict mapping transcript IDs to Transcript objects.
        metadata: Additional structured metadata about the agent run.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str | None = None
    description: str | None = None

    transcripts: dict[str, Transcript]
    metadata: BaseAgentRunMetadata

    @field_serializer("metadata")
    def serialize_metadata(self, metadata: BaseAgentRunMetadata, _info: Any) -> dict[str, Any]:
        """
        Custom serializer for the metadata field so the internal fields are explicitly preserved.
        """
        return metadata.model_dump(strip_internal_fields=False)

    @field_validator("metadata", mode="before")
    @classmethod
    def _validate_metadata_type(cls, v: Any) -> Any:
        """Reject metadata that is neither None nor a BaseAgentRunMetadata instance.

        Raises:
            ValueError: If ``v`` has any other type.
        """
        if v is not None and not isinstance(v, BaseAgentRunMetadata):
            raise ValueError(
                f"metadata must be an instance of BaseAgentRunMetadata, got {type(v).__name__}"
            )
        return v

    @model_validator(mode="after")
    def _validate_transcripts_not_empty(self):
        """Validates that the agent run contains at least one transcript.

        Raises:
            ValueError: If the transcripts dict is empty.

        Returns:
            AgentRun: The validated AgentRun instance.
        """
        if not self.transcripts:
            raise ValueError("AgentRun must have at least one transcript")
        return self

    def to_text(self, token_limit: int = sys.maxsize) -> list[str]:
        """
        Represents an agent run as a list of strings, each of which is intended to be at most
        token_limit tokens under the GPT-4 tokenization scheme.

        We'll try to split up long AgentRuns along transcript boundaries and include metadata.
        For very long transcripts, we'll have to split them up further and remove metadata.

        Args:
            token_limit: Token budget per returned string. The default (sys.maxsize)
                disables token counting and returns a single string.

        Returns:
            One string when the whole run fits, otherwise one string per chunk.
        """

        transcript_strs: list[str] = [
            f"<transcript {t_key}>\n{t.to_str(agent_run_idx=None, transcript_idx=i)}\n</transcript {t_key}>"
            for i, (t_key, t) in enumerate(self.transcripts.items())
        ]
        transcripts_str = "\n\n".join(transcript_strs)

        # Gather metadata
        metadata_obj = self.metadata.model_dump(strip_internal_fields=True)
        if self.name is not None:
            metadata_obj["name"] = self.name
        if self.description is not None:
            metadata_obj["description"] = self.description
        # Add the field descriptions if they exist
        metadata_obj = {
            (f"{k} ({d})" if (d := self.metadata.get_field_description(k)) is not None else k): v
            for k, v in metadata_obj.items()
        }

        # An infinite width keeps yaml.dump from wrapping long values.
        yaml_width = float("inf")
        transcripts_str = (
            f"Here is a complete agent run for analysis purposes only:\n{transcripts_str}\n\n"
        )
        metadata_str = f"Metadata about the complete agent run:\n<agent run metadata>\n{yaml.dump(metadata_obj, width=yaml_width)}\n</agent run metadata>"

        # No limit requested: skip tokenization entirely.
        if token_limit == sys.maxsize:
            return [f"{transcripts_str}{metadata_str}"]

        # Compute message length; if fits, return the full transcript and metadata
        transcript_str_tokens = get_token_count(transcripts_str)
        metadata_str_tokens = get_token_count(metadata_str)
        if transcript_str_tokens + metadata_str_tokens <= token_limit:
            return [f"{transcripts_str}{metadata_str}"]

        # Otherwise, split up the transcript and metadata into chunks
        # TODO(vincent, mengk): does this code account for multiple transcripts correctly? a little confused.
        results: list[str] = []
        # Materialized once so ranges can be mapped back to transcripts by position.
        transcript_items = list(self.transcripts.items())
        transcript_token_counts = [get_token_count(t) for t in transcript_strs]
        # 50 tokens are held back for the "Here is a partial agent run ..." framing.
        ranges = group_messages_into_ranges(
            transcript_token_counts, metadata_str_tokens, token_limit - 50
        )
        for msg_range in ranges:
            if msg_range.include_metadata:
                cur_transcript_str = "\n\n".join(transcript_strs[msg_range.start : msg_range.end])
                # Separate transcripts from metadata with a blank line, as the
                # complete-run format does; without it the metadata header is
                # glued directly onto the closing </transcript ...> tag.
                results.append(
                    f"Here is a partial agent run for analysis purposes only:\n{cur_transcript_str}\n\n"
                    f"{metadata_str}"
                )
                continue

            assert (
                msg_range.end == msg_range.start + 1
            ), "Ranges without metadata should be a single message"
            t_id, t = transcript_items[msg_range.start]
            if msg_range.num_tokens < token_limit - 50:
                # Reuse the rendering whose token count was just checked, so the
                # emitted text (including its transcript_idx) matches what was measured.
                transcript = transcript_strs[msg_range.start]
                results.append(
                    f"Here is a partial agent run for analysis purposes only:\n{transcript}"
                )
            else:
                # The transcript alone exceeds the budget: split it into fragments.
                # NOTE(review): no transcript_idx is passed here, unlike the
                # to_str(...) call above -- confirm whether
                # to_str_with_token_limit should receive one.
                transcript_fragments = t.to_str_with_token_limit(token_limit - 50)
                for fragment in transcript_fragments:
                    result = f"<transcript {t_id}>\n{fragment}\n</transcript {t_id}>"
                    result = f"Here is a partial agent run for analysis purposes only:\n{result}"
                    results.append(result)
        return results

    @property
    def text(self) -> str:
        """Full text of the agent run: all transcripts followed by the run metadata.

        Returns:
            str: The single string produced by to_text() with no token limit.
        """
        return self.to_text()[0]

    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Extends the parent model_dump method to include the text property.

        Args:
            *args: Variable length argument list passed to parent method.
            **kwargs: Arbitrary keyword arguments passed to parent method.

        Returns:
            dict[str, Any]: Dictionary representation of the model including the text property.
        """
        return super().model_dump(*args, **kwargs) | {"text": self.text}

    def get_filterable_fields(self, max_depth: int = 1) -> list[FilterableField]:
        """Returns a list of all fields that can be used to filter the agent run,
        by recursively exploring the model_dump() for singleton types in dictionaries.

        Args:
            max_depth: Deepest nesting level of dictionaries to descend into;
                the top-level metadata dictionary is depth 0.

        Returns:
            list[FilterableField]: A list of filterable fields, where each field is a
                dictionary containing its 'name' (path) and 'type'.
        """

        result: list[FilterableField] = []

        def _explore_dict(d: dict[str, Any], prefix: str, depth: int):
            # `result` is only appended to, never rebound, so no `nonlocal` is needed.
            if depth > max_depth:
                return

            for k, v in d.items():
                if isinstance(v, (str, int, float, bool)):
                    result.append(
                        {
                            "name": f"{prefix}.{k}",
                            "type": cast(Literal["str", "bool", "int", "float"], type(v).__name__),
                        }
                    )
                elif isinstance(v, dict):
                    _explore_dict(cast(dict[str, Any], v), f"{prefix}.{k}", depth + 1)

        # Look at the agent run metadata
        _explore_dict(self.metadata.model_dump(strip_internal_fields=True), "metadata", 0)
        # Look at the transcript metadata
        # TODO(mengk): restore this later when we have the ability to integrate with SQL.
        # for t_id, t in self.transcripts.items():
        #     _explore_dict(
        #         t.metadata.model_dump(strip_internal_fields=True), f"transcript.{t_id}.metadata", 0
        #     )

        # Append the text field
        result.append({"name": "text", "type": "str"})

        return result
217
+
218
+
219
class AgentRunWithoutMetadataValidator(AgentRun):
    """
    AgentRun variant whose "metadata" field validator accepts any value.
    Needed for sending/receiving agent runs via JSON, since those payloads
    incorrectly trip the metadata type check on AgentRun.
    """

    # Transcripts use the validator-free transcript model as well.
    transcripts: dict[str, TranscriptWithoutMetadataValidator]  # type: ignore

    @field_validator("metadata", mode="before")
    @classmethod
    def _validate_metadata_type(cls, v: Any) -> Any:
        """Return ``v`` untouched, replacing the parent's isinstance check."""
        return v
@@ -0,0 +1,25 @@
1
+ from docent.data_models.chat.content import Content, ContentReasoning, ContentText
2
+ from docent.data_models.chat.message import (
3
+ AssistantMessage,
4
+ ChatMessage,
5
+ SystemMessage,
6
+ ToolMessage,
7
+ UserMessage,
8
+ parse_chat_message,
9
+ )
10
+ from docent.data_models.chat.tool import ToolCall, ToolInfo, ToolParams
11
+
12
+ __all__ = [
13
+ "ChatMessage",
14
+ "AssistantMessage",
15
+ "SystemMessage",
16
+ "ToolMessage",
17
+ "UserMessage",
18
+ "Content",
19
+ "ContentReasoning",
20
+ "ContentText",
21
+ "ToolCall",
22
+ "ToolInfo",
23
+ "ToolParams",
24
+ "parse_chat_message",
25
+ ]
@@ -0,0 +1,56 @@
1
+ from typing import Annotated, Literal
2
+
3
+ from pydantic import BaseModel, Discriminator
4
+
5
+
6
class BaseContent(BaseModel):
    """Common parent of the content models used in chat messages.

    Declares the ``type`` field that subclasses narrow to a single literal so
    the content union can be discriminated on it.

    Attributes:
        type: Identifier of the content kind.
    """

    type: Literal["text", "reasoning", "image", "audio", "video"]
16
+
17
+
18
class ContentText(BaseContent):
    """Plain text content of a chat message.

    Attributes:
        type: Always "text"; identifies this content kind.
        text: The text itself.
        refusal: Optional flag marking the text as a refusal; None when unspecified.
    """

    type: Literal["text"] = "text"  # type: ignore
    text: str
    refusal: bool | None = None
32
+
33
+
34
class ContentReasoning(BaseContent):
    """Reasoning (thought process) content of a chat message.

    Attributes:
        type: Always "reasoning"; identifies this content kind.
        reasoning: The reasoning text.
        signature: Optional signature associated with the reasoning.
        redacted: Whether the reasoning has been redacted; False by default.
    """

    type: Literal["reasoning"] = "reasoning"  # type: ignore
    reasoning: str
    signature: str | None = None
    redacted: bool = False
50
+
51
+
52
# Union of the concrete content models; pydantic picks the member to validate
# against by looking at the "type" field.
Content = Annotated[ContentText | ContentReasoning, Discriminator("type")]
"""Discriminated union of possible content types using the 'type' field.
Can be either ContentText or ContentReasoning.
"""