docent-python 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python might be problematic. Click here for more details.
- docent/__init__.py +3 -0
- docent/_log_util/__init__.py +3 -0
- docent/_log_util/logger.py +141 -0
- docent/data_models/__init__.py +25 -0
- docent/data_models/_tiktoken_util.py +91 -0
- docent/data_models/agent_run.py +231 -0
- docent/data_models/chat/__init__.py +25 -0
- docent/data_models/chat/content.py +56 -0
- docent/data_models/chat/message.py +125 -0
- docent/data_models/chat/tool.py +109 -0
- docent/data_models/citation.py +223 -0
- docent/data_models/filters.py +205 -0
- docent/data_models/metadata.py +219 -0
- docent/data_models/regex.py +56 -0
- docent/data_models/shared_types.py +10 -0
- docent/data_models/transcript.py +347 -0
- docent/py.typed +0 -0
- docent/sdk/__init__.py +0 -0
- docent/sdk/client.py +285 -0
- docent_python-0.1.0a1.dist-info/METADATA +16 -0
- docent_python-0.1.0a1.dist-info/RECORD +23 -0
- docent_python-0.1.0a1.dist-info/WHEEL +4 -0
- docent_python-0.1.0a1.dist-info/licenses/LICENSE.md +7 -0
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from logging import getLogger
|
|
2
|
+
from typing import Annotated, Any, Literal
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, Discriminator
|
|
5
|
+
|
|
6
|
+
from docent.data_models.chat.content import Content
|
|
7
|
+
from docent.data_models.chat.tool import ToolCall
|
|
8
|
+
|
|
9
|
+
logger = getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseChatMessage(BaseModel):
    """Common shape shared by every chat message variant.

    Attributes:
        id: Optional unique identifier for the message.
        content: Message payload, either a plain string or a list of Content parts.
        role: Sender role; one of "system", "user", "assistant" or "tool".
    """

    id: str | None = None
    content: str | list[Content]
    role: Literal["system", "user", "assistant", "tool"]

    @property
    def text(self) -> str:
        """Return the textual portion of the message.

        Returns:
            str: ``content`` itself when it is already a string. Otherwise the
                ``text`` of every part whose ``type`` is "text", joined with
                newlines (parts of any other type are skipped).
        """
        payload = self.content
        if isinstance(payload, str):
            return payload
        return "\n".join(part.text for part in payload if part.type == "text")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class SystemMessage(BaseChatMessage):
    """System message in a chat conversation.

    Narrows the inherited ``role`` field to the single literal "system" and
    gives it that value by default.

    Attributes:
        role: Always set to "system".
    """

    # The override narrows the base class's Literal type, hence the suppression.
    role: Literal["system"] = "system"  # type: ignore
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class UserMessage(BaseChatMessage):
    """User message in a chat conversation.

    Narrows the inherited ``role`` field to the single literal "user" and
    gives it that value by default.

    Attributes:
        role: Always set to "user".
        tool_call_id: Optional list of tool call IDs this message is responding to.
            Note that, unlike ``ToolMessage.tool_call_id``, this is a list.
    """

    # The override narrows the base class's Literal type, hence the suppression.
    role: Literal["user"] = "user"  # type: ignore
    tool_call_id: list[str] | None = None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class AssistantMessage(BaseChatMessage):
    """Assistant message in a chat conversation.

    Narrows the inherited ``role`` field to the single literal "assistant" and
    gives it that value by default.

    Attributes:
        role: Always set to "assistant".
        model: Optional identifier for the model that generated this message.
        tool_calls: Optional list of tool calls made by the assistant.
    """

    # The override narrows the base class's Literal type, hence the suppression.
    role: Literal["assistant"] = "assistant"  # type: ignore
    model: str | None = None
    tool_calls: list[ToolCall] | None = None
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class ToolMessage(BaseChatMessage):
    """Tool message in a chat conversation.

    Narrows the inherited ``role`` field to the single literal "tool" and
    gives it that value by default.

    Attributes:
        role: Always set to "tool".
        tool_call_id: Optional ID of the tool call this message is responding to.
        function: Optional name of the function that was called.
        error: Optional error information if the tool call failed, as a
            free-form dictionary (no schema is enforced here).
    """

    # The override narrows the base class's Literal type, hence the suppression.
    role: Literal["tool"] = "tool"  # type: ignore

    tool_call_id: str | None = None
    function: str | None = None
    error: dict[str, Any] | None = None
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# Union of the four concrete message classes. The Discriminator tells pydantic
# to pick the union member from the value of the "role" field.
ChatMessage = Annotated[
    SystemMessage | UserMessage | AssistantMessage | ToolMessage,
    Discriminator("role"),
]
"""Type alias for any chat message type, discriminated by the role field."""
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def parse_chat_message(message_data: dict[str, Any] | ChatMessage) -> ChatMessage:
    """Turn raw message data into the ChatMessage subclass matching its role.

    Args:
        message_data: Either a dictionary describing a chat message or an
            already-constructed message object.

    Returns:
        ChatMessage: ``message_data`` unchanged when it is already one of the
            four message classes; otherwise the class selected by the
            dictionary's "role" entry, built with ``model_validate``.

    Raises:
        ValueError: If the message role is unknown (including a missing role).
    """
    if isinstance(message_data, (SystemMessage, UserMessage, AssistantMessage, ToolMessage)):
        return message_data

    # Ordered (role, class) pairs; compared with == so any role value is accepted
    # as input and simply falls through to the error when nothing matches.
    role_to_class = (
        ("system", SystemMessage),
        ("user", UserMessage),
        ("assistant", AssistantMessage),
        ("tool", ToolMessage),
    )

    role = message_data.get("role")
    for role_name, message_cls in role_to_class:
        if role == role_name:
            return message_cls.model_validate(message_data)

    raise ValueError(f"Unknown message role: {role}")
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class ToolCall:
    """Tool call information.

    A plain dataclass (not a pydantic model) describing one tool invocation.

    Attributes:
        id: Unique identifier for tool call.
        type: Type of tool call. Can only be "function" or None.
        function: Function called.
        arguments: Arguments to function.
        parse_error: Error which occurred parsing tool call.
        view: Custom view of tool call input.
    """

    id: str
    type: Literal["function"] | None
    function: str
    arguments: dict[str, Any]
    parse_error: str | None = None
    # ToolCallContent is defined later in the module; the reference is valid
    # because annotations are postponed (from __future__ import annotations).
    view: ToolCallContent | None = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ToolCallContent(BaseModel):
    """Content to include in tool call view.

    Attributes:
        title: Optional (plain text) title for tool call content.
        format: Format of ``content``; either "text" or "markdown".
        content: Text or markdown content.
    """

    title: str | None = None
    format: Literal["text", "markdown"]
    content: str
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ToolParam(BaseModel):
    """A parameter for a tool function.

    All three fields are required.

    Attributes:
        name: The name of the parameter.
        description: A description of what the parameter does.
        input_schema: JSON Schema describing the parameter's type and validation rules.
    """

    name: str
    description: str
    input_schema: dict[str, Any]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ToolParams(BaseModel):
    """Description of tool parameters object in JSON Schema format.

    Every field has a default, so ``ToolParams()`` describes an object with no
    properties.

    Attributes:
        type: The type of the parameters object, always 'object'.
        properties: Dictionary mapping parameter names to their ToolParam definitions.
            Defaults to an empty dict.
        required: List of required parameter names. Defaults to an empty list.
        additionalProperties: Whether additional properties are allowed beyond those
            specified. Defaults to False.
    """

    type: Literal["object"] = "object"
    # default_factory gives each instance its own fresh container.
    properties: dict[str, ToolParam] = Field(default_factory=dict)
    required: list[str] = Field(default_factory=list)
    # camelCase matches the JSON Schema keyword of the same name.
    additionalProperties: bool = False
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ToolInfo(BaseModel):
    """Specification of a tool (JSON Schema compatible).

    If you are implementing a ModelAPI, most LLM libraries can
    be passed this object (dumped to a dict) directly as a function
    specification. For example, in the OpenAI provider:

    ```python
    ChatCompletionToolParam(
        type="function",
        function=tool.model_dump(exclude_none=True),
    )
    ```

    In some cases the field names don't match up exactly. In that case
    call `model_dump()` on the `parameters` field. For example, in the
    Anthropic provider:

    ```python
    ToolParam(
        name=tool.name,
        description=tool.description,
        input_schema=tool.parameters.model_dump(exclude_none=True),
    )
    ```

    Attributes:
        name: Name of tool.
        description: Short description of tool.
        parameters: JSON Schema of tool parameters object. Defaults to an
            empty ``ToolParams()``.
    """

    name: str
    description: str
    # default_factory builds a fresh ToolParams for every ToolInfo instance.
    parameters: ToolParams = Field(default_factory=ToolParams)
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import TypedDict
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Citation(TypedDict):
|
|
6
|
+
start_idx: int
|
|
7
|
+
end_idx: int
|
|
8
|
+
agent_run_idx: int | None
|
|
9
|
+
transcript_idx: int | None
|
|
10
|
+
block_idx: int
|
|
11
|
+
action_unit_idx: int | None
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def parse_citations_single_run(text: str) -> list[Citation]:
|
|
15
|
+
"""
|
|
16
|
+
Parse citations from text in the format described by SINGLE_BLOCK_CITE_INSTRUCTION.
|
|
17
|
+
|
|
18
|
+
Supported formats:
|
|
19
|
+
- Single block: [T<key>B<idx>]
|
|
20
|
+
- Multiple blocks: [T<key1>B<idx1>, T<key2>B<idx2>, ...]
|
|
21
|
+
- Dash-separated blocks: [T<key1>B<idx1>-T<key2>B<idx2>]
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
text: The text to parse citations from
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
A list of Citation objects with start_idx and end_idx representing
|
|
28
|
+
the character positions in the text (excluding brackets)
|
|
29
|
+
"""
|
|
30
|
+
citations: list[Citation] = []
|
|
31
|
+
|
|
32
|
+
# Find all bracketed content first
|
|
33
|
+
bracket_pattern = r"\[(.*?)\]"
|
|
34
|
+
bracket_matches = re.finditer(bracket_pattern, text)
|
|
35
|
+
|
|
36
|
+
for bracket_match in bracket_matches:
|
|
37
|
+
bracket_content = bracket_match.group(1)
|
|
38
|
+
# Starting position of the bracket content (excluding '[')
|
|
39
|
+
content_start_pos = bracket_match.start() + 1
|
|
40
|
+
|
|
41
|
+
# Split by commas if present
|
|
42
|
+
parts = [part.strip() for part in bracket_content.split(",")]
|
|
43
|
+
|
|
44
|
+
for part in parts:
|
|
45
|
+
# Check if this part contains a dash (range citation)
|
|
46
|
+
if "-" in part:
|
|
47
|
+
# Split by dash and process each sub-part
|
|
48
|
+
dash_parts = [dash_part.strip() for dash_part in part.split("-")]
|
|
49
|
+
for dash_part in dash_parts:
|
|
50
|
+
# Check for single block citation: T<key>B<idx>
|
|
51
|
+
single_match = re.match(r"T(\d+)B(\d+)", dash_part)
|
|
52
|
+
if single_match:
|
|
53
|
+
transcript_idx = int(single_match.group(1))
|
|
54
|
+
block_idx = int(single_match.group(2))
|
|
55
|
+
|
|
56
|
+
# Find position within the original text
|
|
57
|
+
citation_text = f"T{transcript_idx}B{block_idx}"
|
|
58
|
+
part_pos_in_content = bracket_content.find(dash_part)
|
|
59
|
+
ref_pos = content_start_pos + part_pos_in_content
|
|
60
|
+
ref_end = ref_pos + len(citation_text)
|
|
61
|
+
|
|
62
|
+
# Check if this citation overlaps with any existing citation
|
|
63
|
+
if not any(
|
|
64
|
+
citation["start_idx"] <= ref_pos < citation["end_idx"]
|
|
65
|
+
or citation["start_idx"] < ref_end <= citation["end_idx"]
|
|
66
|
+
for citation in citations
|
|
67
|
+
):
|
|
68
|
+
citations.append(
|
|
69
|
+
Citation(
|
|
70
|
+
start_idx=ref_pos,
|
|
71
|
+
end_idx=ref_end,
|
|
72
|
+
agent_run_idx=None,
|
|
73
|
+
transcript_idx=transcript_idx,
|
|
74
|
+
block_idx=block_idx,
|
|
75
|
+
action_unit_idx=None,
|
|
76
|
+
)
|
|
77
|
+
)
|
|
78
|
+
else:
|
|
79
|
+
# Check for single block citation: T<key>B<idx>
|
|
80
|
+
single_match = re.match(r"T(\d+)B(\d+)", part)
|
|
81
|
+
if single_match:
|
|
82
|
+
transcript_idx = int(single_match.group(1))
|
|
83
|
+
block_idx = int(single_match.group(2))
|
|
84
|
+
|
|
85
|
+
# Find position within the original text
|
|
86
|
+
citation_text = f"T{transcript_idx}B{block_idx}"
|
|
87
|
+
part_pos_in_content = bracket_content.find(part)
|
|
88
|
+
ref_pos = content_start_pos + part_pos_in_content
|
|
89
|
+
ref_end = ref_pos + len(citation_text)
|
|
90
|
+
|
|
91
|
+
# Check if this citation overlaps with any existing citation
|
|
92
|
+
if not any(
|
|
93
|
+
citation["start_idx"] <= ref_pos < citation["end_idx"]
|
|
94
|
+
or citation["start_idx"] < ref_end <= citation["end_idx"]
|
|
95
|
+
for citation in citations
|
|
96
|
+
):
|
|
97
|
+
citations.append(
|
|
98
|
+
Citation(
|
|
99
|
+
start_idx=ref_pos,
|
|
100
|
+
end_idx=ref_end,
|
|
101
|
+
agent_run_idx=None,
|
|
102
|
+
transcript_idx=transcript_idx,
|
|
103
|
+
block_idx=block_idx,
|
|
104
|
+
action_unit_idx=None,
|
|
105
|
+
)
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
return citations
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def parse_citations_multi_run(text: str) -> list[Citation]:
|
|
112
|
+
"""
|
|
113
|
+
Parse citations from text in the format described by MULTI_BLOCK_CITE_INSTRUCTION.
|
|
114
|
+
|
|
115
|
+
Supported formats:
|
|
116
|
+
- Single block in transcript: [R<idx>T<key>B<idx>] or ([R<idx>T<key>B<idx>])
|
|
117
|
+
- Multiple blocks: [R<idx1>T<key1>B<idx1>][R<idx2>T<key2>B<idx2>]
|
|
118
|
+
- Comma-separated blocks: [R<idx1>T<key1>B<idx1>, R<idx2>T<key2>B<idx2>, ...]
|
|
119
|
+
- Dash-separated blocks: [R<idx1>T<key1>B<idx1>-R<idx2>T<key2>B<idx2>]
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
text: The text to parse citations from
|
|
123
|
+
|
|
124
|
+
Returns:
|
|
125
|
+
A list of Citation objects with start_idx and end_idx representing
|
|
126
|
+
the character positions in the text (excluding brackets)
|
|
127
|
+
"""
|
|
128
|
+
citations: list[Citation] = []
|
|
129
|
+
|
|
130
|
+
# Find all content within brackets - this handles nested brackets too
|
|
131
|
+
bracket_pattern = r"\[([^\[\]]*(?:\[[^\[\]]*\][^\[\]]*)*)\]"
|
|
132
|
+
# Also handle optional parentheses around the brackets
|
|
133
|
+
paren_bracket_pattern = r"\(\[([^\[\]]*(?:\[[^\[\]]*\][^\[\]]*)*)\]\)"
|
|
134
|
+
|
|
135
|
+
# Single citation pattern
|
|
136
|
+
single_pattern = r"R(\d+)T(\d+)B(\d+)"
|
|
137
|
+
|
|
138
|
+
# Find all bracket matches
|
|
139
|
+
for pattern in [bracket_pattern, paren_bracket_pattern]:
|
|
140
|
+
matches = re.finditer(pattern, text)
|
|
141
|
+
for match in matches:
|
|
142
|
+
# Get the content inside brackets
|
|
143
|
+
if pattern == bracket_pattern:
|
|
144
|
+
content = match.group(1)
|
|
145
|
+
start_pos = match.start() + 1 # +1 to skip the opening bracket
|
|
146
|
+
else:
|
|
147
|
+
content = match.group(1)
|
|
148
|
+
start_pos = match.start() + 2 # +2 to skip the opening parenthesis and bracket
|
|
149
|
+
|
|
150
|
+
# Split by comma if present
|
|
151
|
+
items = [item.strip() for item in content.split(",")]
|
|
152
|
+
|
|
153
|
+
for item in items:
|
|
154
|
+
# Check if this item contains a dash (range citation)
|
|
155
|
+
if "-" in item:
|
|
156
|
+
# Split by dash and process each sub-item
|
|
157
|
+
dash_items = [dash_item.strip() for dash_item in item.split("-")]
|
|
158
|
+
for dash_item in dash_items:
|
|
159
|
+
# Check for single citation
|
|
160
|
+
single_match = re.match(single_pattern, dash_item)
|
|
161
|
+
if single_match:
|
|
162
|
+
agent_run_idx = int(single_match.group(1))
|
|
163
|
+
transcript_idx = int(single_match.group(2))
|
|
164
|
+
block_idx = int(single_match.group(3))
|
|
165
|
+
|
|
166
|
+
# Calculate position in the original text
|
|
167
|
+
citation_text = f"R{agent_run_idx}T{transcript_idx}B{block_idx}"
|
|
168
|
+
citation_start = text.find(citation_text, start_pos)
|
|
169
|
+
citation_end = citation_start + len(citation_text)
|
|
170
|
+
|
|
171
|
+
# Move start_pos for the next item if there are more items
|
|
172
|
+
start_pos = citation_end
|
|
173
|
+
|
|
174
|
+
# Avoid duplicate citations
|
|
175
|
+
if not any(
|
|
176
|
+
citation["start_idx"] == citation_start
|
|
177
|
+
and citation["end_idx"] == citation_end
|
|
178
|
+
for citation in citations
|
|
179
|
+
):
|
|
180
|
+
citations.append(
|
|
181
|
+
Citation(
|
|
182
|
+
start_idx=citation_start,
|
|
183
|
+
end_idx=citation_end,
|
|
184
|
+
agent_run_idx=agent_run_idx,
|
|
185
|
+
transcript_idx=transcript_idx,
|
|
186
|
+
block_idx=block_idx,
|
|
187
|
+
action_unit_idx=None,
|
|
188
|
+
)
|
|
189
|
+
)
|
|
190
|
+
else:
|
|
191
|
+
# Check for single citation
|
|
192
|
+
single_match = re.match(single_pattern, item)
|
|
193
|
+
if single_match:
|
|
194
|
+
agent_run_idx = int(single_match.group(1))
|
|
195
|
+
transcript_idx = int(single_match.group(2))
|
|
196
|
+
block_idx = int(single_match.group(3))
|
|
197
|
+
|
|
198
|
+
# Calculate position in the original text
|
|
199
|
+
citation_text = f"R{agent_run_idx}T{transcript_idx}B{block_idx}"
|
|
200
|
+
citation_start = text.find(citation_text, start_pos)
|
|
201
|
+
citation_end = citation_start + len(citation_text)
|
|
202
|
+
|
|
203
|
+
# Move start_pos for the next item if there are more items
|
|
204
|
+
start_pos = citation_end
|
|
205
|
+
|
|
206
|
+
# Avoid duplicate citations
|
|
207
|
+
if not any(
|
|
208
|
+
citation["start_idx"] == citation_start
|
|
209
|
+
and citation["end_idx"] == citation_end
|
|
210
|
+
for citation in citations
|
|
211
|
+
):
|
|
212
|
+
citations.append(
|
|
213
|
+
Citation(
|
|
214
|
+
start_idx=citation_start,
|
|
215
|
+
end_idx=citation_end,
|
|
216
|
+
agent_run_idx=agent_run_idx,
|
|
217
|
+
transcript_idx=transcript_idx,
|
|
218
|
+
block_idx=block_idx,
|
|
219
|
+
action_unit_idx=None,
|
|
220
|
+
)
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
return citations
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal, Type, Union
|
|
4
|
+
from uuid import uuid4
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Discriminator, Field, field_validator
|
|
7
|
+
from sqlalchemy import ColumnElement, and_, or_
|
|
8
|
+
|
|
9
|
+
from docent._log_util import get_logger
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from docent_core._db_service.schemas.tables import SQLAAgentRun
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BaseFrameFilter(BaseModel):
    """Base class for all frame filters.

    Attributes:
        id: Identifier for the filter; defaults to a freshly generated UUID4 string.
        name: Optional name for the filter.
        supports_sql: Flag defaulting to True.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str | None = None
    supports_sql: bool = True  # All filters must support SQL

    def to_sqla_where_clause(self, table: Type["SQLAAgentRun"]) -> ColumnElement[bool] | None:
        """Convert this filter to a SQLAlchemy WHERE clause.

        All filters must implement this method to support SQL execution.

        Args:
            table: The SQLAlchemy table class the clause is built against.

        Raises:
            NotImplementedError: Always, in this base implementation; subclasses
                are expected to override it.
        """
        raise NotImplementedError(
            f"Filter {self.__class__.__name__} must implement to_sqla_where_clause"
        )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class PrimitiveFilter(BaseFrameFilter):
    """Filter that applies a primitive operation to a metadata field.

    Attributes:
        type: Discriminator value, always "primitive".
        key_path: Path to the compared value. The first element selects the
            source ("text" or "metadata"); remaining elements are applied as
            successive subscripts on that source.
        value: The value to compare against.
        op: Comparison operator. "~*" and "!~*" are passed through verbatim as
            SQL operators (in PostgreSQL: case-insensitive regex match and its
            negation).
    """

    type: Literal["primitive"] = "primitive"
    key_path: list[str]
    value: Any
    op: Literal[">", ">=", "<", "<=", "==", "!=", "~*", "!~*"]

    def to_sqla_where_clause(self, table: Type["SQLAAgentRun"]) -> ColumnElement[bool] | None:
        """Convert this filter to a SQLAlchemy WHERE clause.

        Args:
            table: The SQLAlchemy table class whose ``text_for_search`` or
                ``metadata_json`` column is compared.

        Returns:
            A SQLAlchemy boolean expression comparing the selected column (or
            nested JSON value) with ``value`` using ``op``.

        Raises:
            ValueError: If ``key_path[0]`` is neither "text" nor "metadata", if
                ``value`` is not a str, bool, int or float in metadata mode, or
                if ``op`` is not a supported operation.
        """

        mode = self.key_path[0]

        # Extract value from JSONB using the table parameter
        if mode == "text":
            sqla_value = table.text_for_search  # type: ignore
        elif mode == "metadata":
            sqla_value = table.metadata_json  # type: ignore
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        for key in self.key_path[1:]:
            sqla_value = sqla_value[key]

        # Cast the extracted value to the correct type
        # This is only necessary for metadata which is JSONB
        if mode == "metadata":
            if isinstance(self.value, str):
                sqla_value = sqla_value.as_string()
            # bool must be tested before int/float because bool is a subclass of int.
            elif isinstance(self.value, bool):
                sqla_value = sqla_value.as_boolean()
            elif isinstance(self.value, float) or isinstance(self.value, int):  # type: ignore warning about unnecessary comparison
                # if self.value is an int, we may still need to do sql comparisons with floats
                sqla_value = sqla_value.as_float()
            else:
                raise ValueError(f"Unsupported value type: {type(self.value)}")

        # Handle different operations using SQLAlchemy expressions
        if self.op == "==":
            return sqla_value == self.value
        elif self.op == "!=":
            return sqla_value != self.value
        elif self.op == ">":
            return sqla_value > self.value
        elif self.op == ">=":
            return sqla_value >= self.value
        elif self.op == "<":
            return sqla_value < self.value
        elif self.op == "<=":
            return sqla_value <= self.value
        elif self.op == "~*":
            return sqla_value.op("~*")(self.value)
        elif self.op == "!~*":
            # Declared as a valid op on the model, so it must be translatable too.
            return sqla_value.op("!~*")(self.value)
        else:
            raise ValueError(f"Unsupported operation: {self.op}")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ComplexFilter(BaseFrameFilter):
    """Filter that combines several sub-filters with a single AND or OR.

    Attributes:
        type: Discriminator value, always "complex".
        filters: The sub-filters to combine; must contain at least one.
        op: How the sub-filters are combined, "and" (default) or "or".
    """

    type: Literal["complex"] = "complex"
    filters: list[FrameFilter]
    op: Literal["and", "or"] = "and"

    @field_validator("filters")
    @classmethod
    def validate_filters(cls, v: list[FrameFilter]) -> list[FrameFilter]:
        """Reject an empty ``filters`` list; otherwise return it unchanged."""
        if v:
            return v
        raise ValueError("ComplexFilter must have at least one filter")

    def to_sqla_where_clause(self, table: Type["SQLAAgentRun"]) -> ColumnElement[bool] | None:
        """Build one SQLAlchemy WHERE clause out of the sub-filters' clauses.

        Sub-filters whose own clause is None are left out of the combination.

        Returns:
            The sub-clauses joined with ``and_`` or ``or_`` according to ``op``,
            or None when no sub-filter produced a clause.

        Raises:
            ValueError: If ``op`` is neither "and" nor "or".
        """
        sub_clauses: list[ColumnElement[bool]] = [
            clause
            for sub_filter in self.filters
            if (clause := sub_filter.to_sqla_where_clause(table)) is not None
        ]

        if not sub_clauses:
            return None

        if self.op == "and":
            return and_(*sub_clauses)
        if self.op == "or":
            return or_(*sub_clauses)
        raise ValueError(f"Unsupported operation: {self.op}")
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class AgentRunIdFilter(BaseFrameFilter):
    """Filter that matches specific agent run IDs.

    Attributes:
        type: Discriminator value, always "agent_run_id".
        agent_run_ids: The IDs to match; must contain at least one.
    """

    type: Literal["agent_run_id"] = "agent_run_id"
    agent_run_ids: list[str]

    @field_validator("agent_run_ids")
    @classmethod
    def validate_agent_run_ids(cls, v: list[str]) -> list[str]:
        """Reject an empty ID list; otherwise return it unchanged."""
        if not v:
            raise ValueError("AgentRunIdFilter must have at least one agent run ID")
        return v

    def to_sqla_where_clause(self, table: Type["SQLAAgentRun"]) -> ColumnElement[bool] | None:
        """Convert to SQLAlchemy WHERE clause for agent run ID filtering.

        Returns:
            An ``IN`` expression testing ``table.id`` against ``agent_run_ids``.
        """
        return table.id.in_(self.agent_run_ids)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class SearchResultPredicateFilter(BaseFrameFilter):
    """Filter that applies a predicate to search results.

    Attributes:
        type: Discriminator value, always "search_result_predicate".
        predicate: The predicate, stored as a string (not interpreted in this class).
        search_query: The search query, stored as a string (not interpreted in this class).
    """

    type: Literal["search_result_predicate"] = "search_result_predicate"
    predicate: str
    search_query: str

    def to_sqla_where_clause(self, table: Type["SQLAAgentRun"]) -> ColumnElement[bool] | None:
        """Convert to SQLAlchemy WHERE clause for search result filtering.

        Returns:
            Always None in the current implementation; ``table``, ``predicate``
            and ``search_query`` are not used.
        """
        # This filter requires joining with search results table
        # For now, we'll return None to indicate it needs special handling
        # In practice, this would join with the search_results table
        return None
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class SearchResultExistsFilter(BaseFrameFilter):
    """Filter that checks if search results exist.

    Attributes:
        type: Discriminator value, always "search_result_exists".
        search_query: The search query, stored as a string (not interpreted in this class).
    """

    type: Literal["search_result_exists"] = "search_result_exists"
    search_query: str

    def to_sqla_where_clause(self, table: Type["SQLAAgentRun"]) -> ColumnElement[bool] | None:
        """Convert to SQLAlchemy WHERE clause for search result existence filtering.

        Returns:
            Always None in the current implementation; ``table`` and
            ``search_query`` are not used.
        """
        # This filter requires joining with search results table
        # For now, we'll return None to indicate it needs special handling
        # In practice, this would join with the search_results table
        return None
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# Union of every concrete filter class. The Discriminator tells pydantic to
# pick the union member from the value of each filter's "type" field.
FrameFilter = Annotated[
    Union[
        PrimitiveFilter,
        ComplexFilter,
        AgentRunIdFilter,
        SearchResultPredicateFilter,
        SearchResultExistsFilter,
    ],
    Discriminator("type"),
]
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def parse_filter_dict(filter_dict: dict[str, Any]) -> FrameFilter:
    """Parse a filter dictionary into a FrameFilter object.

    The filter class is chosen from the dictionary's "type" entry. For
    "complex" filters, every entry of "filters" is parsed recursively first.
    ``filter_dict`` itself is never modified, so the same dictionary can be
    parsed again later.

    Args:
        filter_dict: Dictionary of filter fields, including "type".

    Returns:
        FrameFilter: An instance of the filter class matching "type".

    Raises:
        ValueError: If "type" is missing or not a known filter type.
    """
    filter_type = filter_dict.get("type")

    if filter_type == "primitive":
        return PrimitiveFilter(**filter_dict)
    elif filter_type == "complex":
        # Recursively parse nested filters
        nested_filters = [parse_filter_dict(f) for f in filter_dict.get("filters", [])]
        # Build fresh kwargs rather than assigning into filter_dict: writing the
        # parsed objects back would leave the caller's dict holding model
        # instances instead of the plain dicts they passed in.
        return ComplexFilter(**{**filter_dict, "filters": nested_filters})
    elif filter_type == "agent_run_id":
        return AgentRunIdFilter(**filter_dict)
    elif filter_type == "search_result_predicate":
        return SearchResultPredicateFilter(**filter_dict)
    elif filter_type == "search_result_exists":
        return SearchResultExistsFilter(**filter_dict)
    else:
        raise ValueError(f"Unknown filter type: {filter_type}")
|