notte-agent 0.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- notte_agent/README.md +58 -0
- notte_agent/__init__.py +7 -0
- notte_agent/common/__init__.py +0 -0
- notte_agent/common/base.py +14 -0
- notte_agent/common/captcha_detector.py +87 -0
- notte_agent/common/config.py +219 -0
- notte_agent/common/conversation.py +246 -0
- notte_agent/common/notifier.py +55 -0
- notte_agent/common/parser.py +78 -0
- notte_agent/common/perception.py +21 -0
- notte_agent/common/prompt.py +15 -0
- notte_agent/common/safe_executor.py +100 -0
- notte_agent/common/trajectory_history.py +100 -0
- notte_agent/common/types.py +41 -0
- notte_agent/common/validator.py +90 -0
- notte_agent/falco/__init__.py +0 -0
- notte_agent/falco/agent.py +343 -0
- notte_agent/falco/perception.py +83 -0
- notte_agent/falco/prompt.py +132 -0
- notte_agent/falco/prompts/system_prompt_multi_actions.md +107 -0
- notte_agent/falco/prompts/system_prompt_single_action.md +107 -0
- notte_agent/falco/trajectory_history.py +42 -0
- notte_agent/falco/types.py +132 -0
- notte_agent/gufo/__init__.py +0 -0
- notte_agent/gufo/agent.py +180 -0
- notte_agent/gufo/parser.py +79 -0
- notte_agent/gufo/perception.py +53 -0
- notte_agent/gufo/prompt.py +61 -0
- notte_agent/gufo/system.md +8 -0
- notte_agent/main.py +77 -0
- notte_agent/py.typed +0 -0
- notte_agent-0.0.dev0.dist-info/METADATA +8 -0
- notte_agent-0.0.dev0.dist-info/RECORD +34 -0
- notte_agent-0.0.dev0.dist-info/WHEEL +4 -0
notte_agent/README.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# How to build an LLM agent with *Notte*
|
|
2
|
+
|
|
3
|
+
This guide explains how to build a custom LLM agent using *Notte*. The example in `agent.py` demonstrates a basic implementation that you can customize for your specific needs.
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
*Notte* provides a flexible environment for web automation that can be controlled through an API. To build an agent with *Notte*, you need:
|
|
8
|
+
|
|
9
|
+
1. An agent implementation that coordinates between your LLM and the *Notte* environment
|
|
10
|
+
2. A parser that formats *Notte*'s outputs into prompts suitable for your LLM
|
|
11
|
+
3. A way to interpret the LLM's responses back into *Notte* commands
|
|
12
|
+
|
|
13
|
+
## Key Components
|
|
14
|
+
|
|
15
|
+
### Agent
|
|
16
|
+
|
|
17
|
+
The `Agent` class in `agent.py` shows how to:
|
|
18
|
+
- Initialize a connection to your LLM service
|
|
19
|
+
- Manage the conversation flow between the LLM and *Notte*
|
|
20
|
+
- Track the state of task completion
|
|
21
|
+
|
|
22
|
+
### Parser
|
|
23
|
+
|
|
24
|
+
The parser is crucial for translating between *Notte* and your LLM. You'll need to:
|
|
25
|
+
|
|
26
|
+
1. Create a custom parser (by extending `BaseNotteParser` or implementing the `Parser` interface)
|
|
27
|
+
2. Define how to format:
|
|
28
|
+
- Observations from web pages
|
|
29
|
+
- Available actions
|
|
30
|
+
- Data extraction results
|
|
31
|
+
- Task completion status
|
|
32
|
+
|
|
33
|
+
The provided `BaseNotteParser` is a simple example that you should modify based on your needs. Consider:
|
|
34
|
+
- The prompt format your LLM works best with
|
|
35
|
+
- How to structure web observations for your specific tasks
|
|
36
|
+
- What action format makes sense for your use case
|
|
37
|
+
- How to handle task completion and data extraction
|
|
38
|
+
|
|
39
|
+
## Example Implementation
|
|
40
|
+
|
|
41
|
+
See `agent.py` for a basic implementation. Key points to customize:
|
|
42
|
+
- The parser implementation
|
|
43
|
+
- The prompt engineering in the conversation flow
|
|
44
|
+
- How task completion is determined
|
|
45
|
+
- Error handling and retry logic
|
|
46
|
+
|
|
47
|
+
## Best Practices
|
|
48
|
+
|
|
49
|
+
1. **Custom Parser**: Don't just use the `BaseNotteParser` as-is. Create your own parser that:
|
|
50
|
+
- Formats observations in a way that makes sense for your LLM
|
|
51
|
+
- Structures action possibilities clearly
|
|
52
|
+
- Handles task-specific data extraction
|
|
53
|
+
|
|
54
|
+
2. **Prompt Engineering**: Carefully design your system prompt and conversation flow
|
|
55
|
+
|
|
56
|
+
3. **Error Handling**: Add robust error handling for both LLM and *Notte* interactions
|
|
57
|
+
|
|
58
|
+
4. **Testing**: Test your parser and agent with different scenarios
|
notte_agent/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from notte_browser.session import NotteSession
|
|
4
|
+
|
|
5
|
+
from notte_agent.common.types import AgentResponse
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseAgent(ABC):
    """Abstract base for agents that operate on a ``NotteSession``.

    The session given to the constructor is exposed as ``self.session``;
    concrete agents implement the asynchronous ``run`` entry point.
    """

    def __init__(self, session: NotteSession):
        self.session: NotteSession = session

    @abstractmethod
    async def run(self, task: str, url: str | None = None) -> AgentResponse:
        """Execute ``task`` and return an ``AgentResponse``.

        ``url`` is optional; how it is used is up to the concrete agent.
        """
        ...
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from typing import final
|
|
2
|
+
|
|
3
|
+
import chevron
|
|
4
|
+
from notte_browser.session import TrajectoryStep
|
|
5
|
+
from notte_core.llms.engine import LLMEngine
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from notte_agent.common.conversation import Conversation
|
|
9
|
+
from notte_agent.common.perception import BasePerception
|
|
10
|
+
|
|
11
|
+
system_rules = """
|
|
12
|
+
You are a captcha detector for web pages.
|
|
13
|
+
Analyze the provided screenshot and determine if there is a captcha present on the page.
|
|
14
|
+
A captcha can be in various forms such as:
|
|
15
|
+
- Image-based challenges
|
|
16
|
+
- Text-based challenges
|
|
17
|
+
- Checkbox-based verification
|
|
18
|
+
- Audio-based challenges
|
|
19
|
+
- Math problems
|
|
20
|
+
- Slider puzzles
|
|
21
|
+
- Blocked by network security
|
|
22
|
+
- etc.
|
|
23
|
+
|
|
24
|
+
Return a JSON object with 3 keys: `has_captcha`, `captcha_type`, and `description`:
|
|
25
|
+
- `has_captcha` is a boolean indicating if a captcha is present
|
|
26
|
+
- `captcha_type` is a string describing the type of captcha (or "none" if no captcha)
|
|
27
|
+
- `description` is a string providing details about the captcha's appearance and location
|
|
28
|
+
|
|
29
|
+
Example:
|
|
30
|
+
```json
|
|
31
|
+
{{&example}}
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
Your turn:
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class CaptchaDetection(BaseModel):
    # Structured answer expected from the captcha-detection LLM call; the field
    # meanings mirror the instructions in `system_rules`.

    # True when the model reports a captcha on the page.
    has_captcha: bool
    # Free-form captcha category; the prompt asks for "none" when there is no captcha.
    captcha_type: str
    # Details about the captcha's appearance and location on the page.
    description: str
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@final
class CaptchaDetector:
    """Ask an LLM whether the page captured in a trajectory step shows a captcha.

    A one-shot system prompt (``system_rules`` rendered with an example answer)
    is sent together with the perceived page, the step's action and, when
    ``use_vision`` is enabled, the step's screenshot. The reply is parsed into a
    ``CaptchaDetection``.
    """

    def __init__(
        self,
        llm: LLMEngine,
        perception: BasePerception,
        use_vision: bool = True,
        include_attributes: bool = True,
    ):
        # When True, the step's screenshot (if available) is attached to the user message.
        self.use_vision = use_vision
        # NOTE(review): stored but not read anywhere in this class.
        self.include_attributes = include_attributes
        self.llm: LLMEngine = llm
        self.conv: Conversation = Conversation()
        self.perception: BasePerception = perception

    @staticmethod
    def example() -> CaptchaDetection:
        """Return the example answer embedded in the system prompt."""
        return CaptchaDetection(
            has_captcha=True,
            captcha_type="image-based",
            description="A grid of 9 images is shown with the instruction 'Select all images containing traffic lights'",
        )

    def detection_message(self, step: TrajectoryStep) -> str:
        """Build the text part of the user message for ``step``.

        Contains the perception of ``step.obs`` and the JSON dump of ``step.action``.
        """
        return f"""
Current page screenshot:
{self.perception.perceive(step.obs)}

Current page state:
{step.action.model_dump_json(exclude_unset=True)}
"""

    def _screenshot(self, step: TrajectoryStep) -> bytes | None:
        """Return the step's screenshot bytes when vision is enabled, else None."""
        if not self.use_vision:
            return None
        # NOTE(review): assumes the observation exposes a `screenshot` bytes attribute;
        # getattr keeps this a no-op (text-only message) if it does not.
        screenshot = getattr(step.obs, "screenshot", None)
        return screenshot if isinstance(screenshot, bytes) else None

    def detect(
        self,
        step: TrajectoryStep,
    ) -> CaptchaDetection:
        """Detect if there is a captcha present in the current page screenshot.

        The conversation is reset on every call, so each detection is independent.

        Args:
            step: The trajectory step whose observation/action are analysed.

        Returns:
            The LLM's structured ``CaptchaDetection`` answer.
        """
        self.conv.reset()
        system_prompt = chevron.render(system_rules, {"example": self.example().model_dump_json()})
        self.conv.add_system_message(content=system_prompt)
        # The prompt asks the model to analyse a screenshot, so actually send it
        # when vision is enabled (previously `use_vision` was ignored).
        self.conv.add_user_message(content=self.detection_message(step), image=self._screenshot(step))

        answer: CaptchaDetection = self.llm.structured_completion(self.conv.messages(), CaptchaDetection)
        return answer
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from argparse import ArgumentParser, Namespace
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from enum import StrEnum
|
|
5
|
+
from typing import Any, ClassVar, Self, get_origin, get_type_hints
|
|
6
|
+
|
|
7
|
+
from notte_browser.session import NotteSessionConfig
|
|
8
|
+
from notte_core.common.config import FrozenConfig
|
|
9
|
+
from notte_core.llms.engine import LlmModel
|
|
10
|
+
from notte_sdk.types import DEFAULT_MAX_NB_STEPS
|
|
11
|
+
from pydantic import Field, model_validator
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class RaiseCondition(StrEnum):
    """How to raise an error when the agent fails to complete a step.

    Either immediately upon failure, after retry, or never.
    """

    # Raise as soon as a step fails (selected by `AgentConfig.dev_mode`).
    IMMEDIATELY = "immediately"
    # Raise only after retrying (the `AgentConfig.raise_condition` default).
    RETRY = "retry"
    # Never raise on step failure.
    NEVER = "never"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DefaultAgentArgs(StrEnum):
|
|
26
|
+
SESSION_DISABLE_WEB_SECURITY = "disable_web_security"
|
|
27
|
+
SESSION_HEADLESS = "headless"
|
|
28
|
+
SESSION_PERCEPTION_MODEL = "perception_model"
|
|
29
|
+
SESSION_MAX_STEPS = "max_steps"
|
|
30
|
+
|
|
31
|
+
def with_prefix(self: Self, prefix: str = "session") -> str:
|
|
32
|
+
return f"{prefix}.{self.value}"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class AgentConfig(FrozenConfig, ABC):
    """Base configuration shared by agents.

    Subclasses implement ``default_session``; the ``set_session`` validator
    fills the ``session`` field from it and rejects a user-provided session
    unless ``force_session`` is truthy. The fluent helpers below return configs
    built with ``_copy_and_validate``.
    """

    # make session private to avoid exposing the NotteSessionConfig class
    session: NotteSessionConfig = Field(init=False)
    reasoning_model: str = Field(
        default=LlmModel.default(), description="The model to use for reasoning (i.e taking actions)."
    )
    include_screenshot: bool = Field(default=False, description="Whether to include a screenshot in the response.")
    max_history_tokens: int | None = Field(
        default=None,
        description="The maximum number of tokens in the history. When the history exceeds this limit, the oldest messages are discarded.",
    )
    max_error_length: int = Field(
        default=500, description="The maximum length of an error message to be forwarded to the reasoning model."
    )
    raise_condition: RaiseCondition = Field(
        default=RaiseCondition.RETRY, description="How to raise an error when the agent fails to complete a step."
    )
    max_consecutive_failures: int = Field(
        default=3, description="The maximum number of consecutive failures before the agent gives up."
    )
    force_session: bool | None = Field(
        default=None,
        description="Whether to allow the user to set the session or not.",
    )
    human_in_the_loop: bool = Field(default=False, description="Whether to enable human-in-the-loop mode.")

    @classmethod
    @abstractmethod
    def default_session(cls) -> NotteSessionConfig:
        """Return the session configuration injected by ``set_session``."""
        raise NotImplementedError("Subclasses must implement this method")

    @model_validator(mode="before")
    @classmethod
    def set_session(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Inject ``default_session()`` unless an explicit session is forced.

        Raises:
            ValueError: if ``session`` is provided without a truthy ``force_session``.
        """
        if "session" in values:
            if "force_session" in values and values["force_session"]:
                # `force_session` only authorizes this call; it is dropped so the
                # field falls back to its default.
                del values["force_session"]
                return values
            raise ValueError("Session should not be set by the user. Set `default_session` instead.")
        values["session"] = cls.default_session()  # Set the session field using the subclass's method
        return values

    def groq(self: Self, deep: bool = True) -> Self:
        """Shortcut for ``self.model(LlmModel.groq, deep=deep)``."""
        return self.model(LlmModel.groq, deep=deep)

    def openai(self: Self, deep: bool = True) -> Self:
        """Shortcut for ``self.model(LlmModel.openai, deep=deep)``."""
        return self.model(LlmModel.openai, deep=deep)

    def gemini(self: Self, deep: bool = True) -> Self:
        """Shortcut for ``self.model(LlmModel.gemini, deep=deep)``."""
        return self.model(LlmModel.gemini, deep=deep)

    def cerebras(self: Self, deep: bool = True) -> Self:
        """Shortcut for ``self.model(LlmModel.cerebras, deep=deep)``."""
        return self.model(LlmModel.cerebras, deep=deep)

    def model(self: Self, model: LlmModel, deep: bool = True) -> Self:
        """Reason with ``model`` and size ``max_history_tokens`` to its context length.

        When ``deep`` is True the session config is also updated with
        ``session.model(model)``.
        """
        config = self._copy_and_validate(reasoning_model=model, max_history_tokens=LlmModel.context_length(model))
        if deep:
            config = config.map_session(lambda session: session.model(model))
        return config

    def use_vision(self: Self, value: bool = True) -> Self:
        """Set ``include_screenshot``."""
        return self._copy_and_validate(include_screenshot=value)

    def set_human_in_the_loop(self: Self, value: bool = True) -> Self:
        """Set ``human_in_the_loop``."""
        return self._copy_and_validate(human_in_the_loop=value)

    def dev_mode(self: Self) -> Self:
        """Raise immediately, forward longer errors and switch the session to dev mode."""
        return self._copy_and_validate(
            raise_condition=RaiseCondition.IMMEDIATELY,
            max_error_length=1000,
            session=self.session.dev_mode(),
            force_session=True,
        )

    def set_raise_condition(self: Self, value: RaiseCondition) -> Self:
        """Set ``raise_condition``."""
        return self._copy_and_validate(raise_condition=value)

    def map_session(self: Self, ft: Callable[[NotteSessionConfig], NotteSessionConfig]) -> Self:
        """Replace the session with ``ft(self.session)`` (forced past ``set_session``)."""
        return self._copy_and_validate(session=ft(self.session), force_session=True)

    @staticmethod
    def _parse_bool(value: str) -> bool:
        """Convert a command-line string into a bool.

        ``argparse`` with ``type=bool`` treats every non-empty string (including
        ``"False"``) as True, so boolean fields need an explicit parser.

        Raises:
            ValueError: for unrecognised strings (argparse reports it as invalid input).
        """
        normalized = value.strip().lower()
        if normalized in ("true", "1", "yes", "y", "on"):
            return True
        if normalized in ("false", "0", "no", "n", "off"):
            return False
        raise ValueError(f"invalid boolean value: {value!r}")

    @staticmethod
    def _get_arg_type(python_type: Any) -> Any:
        """Maps Python types to argparse types.

        ``X | None`` is unwrapped to ``X``; booleans are parsed strictly; any
        other type falls back to ``str`` and is left to pydantic validation.
        """
        import types
        from typing import Union, get_args

        if get_origin(python_type) in (Union, types.UnionType):
            non_none = [arg for arg in get_args(python_type) if arg is not type(None)]
            if len(non_none) == 1:
                python_type = non_none[0]

        type_map: dict[Any, Any] = {
            str: str,
            int: int,
            float: float,
            bool: AgentConfig._parse_bool,
        }
        return type_map.get(python_type, str)

    @staticmethod
    def create_base_parser() -> ArgumentParser:
        """Create an ArgumentParser holding the shared ``--session.*`` options."""
        parser = ArgumentParser()
        _ = parser.add_argument(
            f"--{DefaultAgentArgs.SESSION_HEADLESS.with_prefix()}",
            action="store_true",
            help="Whether to run the browser in headless mode.",
        )
        _ = parser.add_argument(
            f"--{DefaultAgentArgs.SESSION_DISABLE_WEB_SECURITY.with_prefix()}",
            action="store_true",
            help="Whether to disable web security.",
        )
        _ = parser.add_argument(
            f"--{DefaultAgentArgs.SESSION_PERCEPTION_MODEL.with_prefix()}",
            type=str,
            default=LlmModel.default(),
            help="The model to use for perception.",
        )
        _ = parser.add_argument(
            f"--{DefaultAgentArgs.SESSION_MAX_STEPS.with_prefix()}",
            type=int,
            default=DEFAULT_MAX_NB_STEPS,
            help="The maximum number of steps the agent can take.",
        )
        return parser

    @classmethod
    def create_parser(cls) -> ArgumentParser:
        """Creates an ArgumentParser with all the fields from the config.

        Every model field except ``session`` (and any ClassVar) becomes a
        ``--field-name`` option with the field's default and description.
        """
        parser = cls.create_base_parser()
        hints = get_type_hints(cls)

        for field_name, field_info in cls.model_fields.items():
            if field_name == "session":
                continue
            field_type = hints.get(field_name)
            if get_origin(field_type) is ClassVar:
                continue

            default = field_info.default
            help_text = field_info.description or "no description available"
            arg_type = cls._get_arg_type(field_type)

            _ = parser.add_argument(
                f"--{field_name.replace('_', '-')}",
                type=arg_type,
                default=default,
                help=f"{help_text} (default: {default})",
            )

        return parser

    @classmethod
    def from_args(cls: type[Self], args: Namespace) -> Self:
        """Creates an AgentConfig from a Namespace of arguments.

        The return type will match the class that called this method.
        ``session.*`` arguments are applied to the session config; the rest are
        passed to the agent config constructor.
        """
        disallowed_args = ["task", "session.window.headless"]
        session_prefix = "session."

        session_args = {
            k.removeprefix(session_prefix).replace("-", "_"): v
            for k, v in vars(args).items()
            if k.startswith(session_prefix) and k not in disallowed_args
        }
        agent_args = {
            k.replace("-", "_"): v
            for k, v in vars(args).items()
            if not k.startswith(session_prefix) and k not in disallowed_args
        }

        def update_session(session: NotteSessionConfig) -> NotteSessionConfig:
            # Work on a copy so the captured `session_args` dict is never mutated.
            remaining = dict(session_args)
            operations: list[Callable[[NotteSessionConfig], NotteSessionConfig]] = []
            # headless / web security go through dedicated session helpers rather
            # than plain field updates.
            if DefaultAgentArgs.SESSION_HEADLESS in remaining:
                headless = remaining.pop(DefaultAgentArgs.SESSION_HEADLESS)
                operations.append(lambda s: s.headless(headless))
            if DefaultAgentArgs.SESSION_DISABLE_WEB_SECURITY in remaining:
                disable_web_security = remaining.pop(DefaultAgentArgs.SESSION_DISABLE_WEB_SECURITY)
                operations.append(
                    lambda s: s.disable_web_security() if disable_web_security else s.enable_web_security()
                )

            session = session._copy_and_validate(**remaining)
            for operation in operations:
                session = operation(session)
            return session

        return cls(**agent_args).map_session(update_session)
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import TypeVar
|
|
5
|
+
|
|
6
|
+
from litellm import (
|
|
7
|
+
AllMessageValues,
|
|
8
|
+
ChatCompletionAssistantMessage,
|
|
9
|
+
ChatCompletionAssistantToolCall,
|
|
10
|
+
ChatCompletionImageObject,
|
|
11
|
+
ChatCompletionSystemMessage,
|
|
12
|
+
ChatCompletionTextObject,
|
|
13
|
+
ChatCompletionToolMessage,
|
|
14
|
+
ChatCompletionUserMessage,
|
|
15
|
+
ModelResponse, # type: ignore[reportPrivateImportUsage]
|
|
16
|
+
OpenAIMessageContent,
|
|
17
|
+
)
|
|
18
|
+
from litellm.utils import token_counter # type: ignore[reportUnknownVariableType]
|
|
19
|
+
from loguru import logger
|
|
20
|
+
from notte_core.errors.llm import LLMParsingError
|
|
21
|
+
from notte_core.llms.engine import LlmModel, StructuredContent
|
|
22
|
+
from pydantic import BaseModel
|
|
23
|
+
|
|
24
|
+
# Conversation history containers used to build LiteLLM requests
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class CachedMessage:
    """Message with cached token count"""

    # The chat message in LiteLLM format (system/user/assistant/tool).
    message: AllMessageValues
    # Token count computed once when the message is added, reused when trimming history.
    token_count: int
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Pydantic model type targeted by `Conversation.parse_structured_response`.
T = TypeVar("T", bound=BaseModel)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
|
|
39
|
+
class Conversation:
|
|
40
|
+
"""Manages conversation history and message extraction"""
|
|
41
|
+
|
|
42
|
+
history: list[CachedMessage] = field(default_factory=list)
|
|
43
|
+
json_extractor: StructuredContent = field(default_factory=lambda: StructuredContent(inner_tag="json"))
|
|
44
|
+
autosize: bool = False
|
|
45
|
+
model: str = LlmModel.default()
|
|
46
|
+
max_tokens: int | None = None
|
|
47
|
+
conservative_factor: float = 0.8
|
|
48
|
+
|
|
49
|
+
_total_tokens: int = field(default=0, init=False)
|
|
50
|
+
convert_tools_to_assistant: bool = False
|
|
51
|
+
|
|
52
|
+
def __post_init__(self) -> None:
|
|
53
|
+
if self.max_tokens is None:
|
|
54
|
+
self.max_tokens = LlmModel.context_length(self.model)
|
|
55
|
+
|
|
56
|
+
@property
|
|
57
|
+
def default_max_tokens(self) -> int:
|
|
58
|
+
if self.max_tokens is None:
|
|
59
|
+
raise ValueError("max_tokens is not set")
|
|
60
|
+
return self.max_tokens
|
|
61
|
+
|
|
62
|
+
@property
|
|
63
|
+
def conservative_max_tokens(self) -> int:
|
|
64
|
+
"""Since token count isn't 100% accurate, allow to be
|
|
65
|
+
slightly conservative, to make sure we trim under the total context length"""
|
|
66
|
+
return int(self.default_max_tokens * self.conservative_factor)
|
|
67
|
+
|
|
68
|
+
def count_tokens(self, content: AllMessageValues) -> int:
|
|
69
|
+
"""Count the number of tokens in a list of messages"""
|
|
70
|
+
return token_counter(model=self.model, messages=[content])
|
|
71
|
+
|
|
72
|
+
def total_tokens(self) -> int:
|
|
73
|
+
"""Get total tokens in conversation history"""
|
|
74
|
+
return self._total_tokens
|
|
75
|
+
|
|
76
|
+
def trim_history_to_fit(self, new_content: AllMessageValues) -> None:
|
|
77
|
+
"""Trim history to make room for new content while preserving system messages"""
|
|
78
|
+
if not self.autosize:
|
|
79
|
+
return
|
|
80
|
+
|
|
81
|
+
# Always keep system messages
|
|
82
|
+
init_messages: list[CachedMessage] = []
|
|
83
|
+
other_messages: list[CachedMessage] = []
|
|
84
|
+
is_init_msg = True
|
|
85
|
+
for msg in self.history:
|
|
86
|
+
match is_init_msg, msg.message["role"]:
|
|
87
|
+
case True, "system":
|
|
88
|
+
init_messages.append(msg)
|
|
89
|
+
case True, "user":
|
|
90
|
+
# keep first user message as init message (need task description)
|
|
91
|
+
is_init_msg = False
|
|
92
|
+
init_messages.append(msg)
|
|
93
|
+
case _, _:
|
|
94
|
+
other_messages.append(msg)
|
|
95
|
+
|
|
96
|
+
new_content_tokens = self.count_tokens(new_content)
|
|
97
|
+
init_tokens = sum(msg.token_count for msg in init_messages)
|
|
98
|
+
available_tokens = self.conservative_max_tokens - init_tokens - new_content_tokens
|
|
99
|
+
|
|
100
|
+
# Remove oldest non-system messages until we have room
|
|
101
|
+
current_tokens = sum(msg.token_count for msg in other_messages)
|
|
102
|
+
has_trimmed = 0
|
|
103
|
+
while other_messages and current_tokens > available_tokens:
|
|
104
|
+
removed = other_messages.pop(0)
|
|
105
|
+
current_tokens -= removed.token_count
|
|
106
|
+
has_trimmed += 1
|
|
107
|
+
|
|
108
|
+
if has_trimmed > 0:
|
|
109
|
+
logger.info(
|
|
110
|
+
f"Trimmed {has_trimmed} message(s) to stay under max token limit (i.e {self.default_max_tokens // 1000}k)"
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
self.history = init_messages + other_messages
|
|
114
|
+
self._total_tokens = sum(msg.token_count for msg in self.history)
|
|
115
|
+
|
|
116
|
+
def _add_message(self, msg: AllMessageValues) -> None:
|
|
117
|
+
"""Internal helper to add a message with token counting"""
|
|
118
|
+
token_count = self.count_tokens(msg)
|
|
119
|
+
if self.autosize:
|
|
120
|
+
self.trim_history_to_fit(msg)
|
|
121
|
+
cached_msg = CachedMessage(message=msg, token_count=token_count)
|
|
122
|
+
self.history.append(cached_msg)
|
|
123
|
+
self._total_tokens += token_count
|
|
124
|
+
|
|
125
|
+
def add_system_message(self, content: str) -> None:
|
|
126
|
+
"""Add a system message to the conversation"""
|
|
127
|
+
self._add_message(ChatCompletionSystemMessage(role="system", content=content))
|
|
128
|
+
|
|
129
|
+
def format_image_content(self, image: bytes) -> ChatCompletionImageObject:
|
|
130
|
+
image_str = base64.b64encode(image).decode("utf-8")
|
|
131
|
+
return ChatCompletionImageObject(
|
|
132
|
+
type="image_url",
|
|
133
|
+
image_url={"url": f"data:image/png;base64,{image_str}"},
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
def format_user_contents(self, contents: list[str | bytes]) -> OpenAIMessageContent:
|
|
137
|
+
return [
|
|
138
|
+
(
|
|
139
|
+
ChatCompletionTextObject(type="text", text=content)
|
|
140
|
+
if isinstance(content, str)
|
|
141
|
+
else self.format_image_content(content)
|
|
142
|
+
)
|
|
143
|
+
for content in contents
|
|
144
|
+
]
|
|
145
|
+
|
|
146
|
+
def add_user_message(self, content: OpenAIMessageContent, image: bytes | None = None) -> None:
|
|
147
|
+
"""Add a user message to the conversation"""
|
|
148
|
+
_content: OpenAIMessageContent = content
|
|
149
|
+
if image is not None and isinstance(content, str):
|
|
150
|
+
_content = self.format_user_contents([content, image])
|
|
151
|
+
self._add_message(ChatCompletionUserMessage(role="user", content=_content))
|
|
152
|
+
|
|
153
|
+
def add_user_messages(self, contents: list[str | bytes]) -> None:
|
|
154
|
+
"""Add a user message to the conversation"""
|
|
155
|
+
_content: OpenAIMessageContent = self.format_user_contents(contents)
|
|
156
|
+
self._add_message(ChatCompletionUserMessage(role="user", content=_content))
|
|
157
|
+
|
|
158
|
+
def add_assistant_message(self, content: str) -> None:
|
|
159
|
+
"""Add an assistant message to the conversation"""
|
|
160
|
+
self._add_message(ChatCompletionAssistantMessage(role="assistant", content=content))
|
|
161
|
+
|
|
162
|
+
def add_tool_message(self, parsed_content: BaseModel, tool_id: str) -> None:
|
|
163
|
+
"""Add a tool message to the conversation"""
|
|
164
|
+
content: str = str(parsed_content.model_dump(mode="json", exclude_unset=True))
|
|
165
|
+
if not self.convert_tools_to_assistant:
|
|
166
|
+
self._add_message(
|
|
167
|
+
ChatCompletionToolMessage(
|
|
168
|
+
role="tool",
|
|
169
|
+
content=content,
|
|
170
|
+
tool_call_id=tool_id,
|
|
171
|
+
)
|
|
172
|
+
)
|
|
173
|
+
else:
|
|
174
|
+
# Optional, convert tools to assistant role
|
|
175
|
+
self._add_message(
|
|
176
|
+
ChatCompletionAssistantMessage(
|
|
177
|
+
role="assistant",
|
|
178
|
+
content="",
|
|
179
|
+
tool_calls=[
|
|
180
|
+
ChatCompletionAssistantToolCall(
|
|
181
|
+
id=tool_id,
|
|
182
|
+
type="function",
|
|
183
|
+
function={
|
|
184
|
+
"arguments": content,
|
|
185
|
+
"name": parsed_content.__class__.__name__,
|
|
186
|
+
},
|
|
187
|
+
)
|
|
188
|
+
],
|
|
189
|
+
)
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
def parse_structured_response(self, response: ModelResponse | str, model: type[T]) -> T:
|
|
193
|
+
"""Parse a structured response from the LLM into a Pydantic model
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
response: The LLM model response
|
|
197
|
+
model: The Pydantic model class to parse into
|
|
198
|
+
|
|
199
|
+
Returns:
|
|
200
|
+
Instance of the specified Pydantic model
|
|
201
|
+
|
|
202
|
+
Raises:
|
|
203
|
+
LLMParsingError: If response cannot be parsed into the model
|
|
204
|
+
"""
|
|
205
|
+
if isinstance(response, str):
|
|
206
|
+
return model.model_validate(response)
|
|
207
|
+
if not response.choices:
|
|
208
|
+
raise LLMParsingError("No choices in LLM response")
|
|
209
|
+
|
|
210
|
+
choice = response.choices[0]
|
|
211
|
+
# Extract content from either streaming or non-streaming response
|
|
212
|
+
content: str | None = None
|
|
213
|
+
if isinstance(choice, dict):
|
|
214
|
+
message = choice.get("message", {}) # type: ignore[reportUnknownMemberType]
|
|
215
|
+
if isinstance(message, dict):
|
|
216
|
+
content = message.get("content") # type: ignore[reportUnknownMemberType]
|
|
217
|
+
else:
|
|
218
|
+
content = getattr(choice, "text")
|
|
219
|
+
|
|
220
|
+
if not content:
|
|
221
|
+
raise LLMParsingError("No content in LLM response message")
|
|
222
|
+
|
|
223
|
+
try:
|
|
224
|
+
if content is None or not isinstance(content, str):
|
|
225
|
+
raise LLMParsingError("No content in LLM response message")
|
|
226
|
+
extracted = self.json_extractor.extract(content)
|
|
227
|
+
return model.model_validate_json(extracted)
|
|
228
|
+
except (json.JSONDecodeError, ValueError) as e:
|
|
229
|
+
raise LLMParsingError(f"Failed to parse response into {model.__name__}: {str(e)}")
|
|
230
|
+
|
|
231
|
+
def messages(self) -> list[AllMessageValues]:
|
|
232
|
+
"""Get messages in LiteLLM format
|
|
233
|
+
|
|
234
|
+
Returns:
|
|
235
|
+
List of messages formatted for LiteLLM
|
|
236
|
+
|
|
237
|
+
Note:
|
|
238
|
+
This converts our internal message format to litellm's format.
|
|
239
|
+
litellm only supports 'assistant' role, so we map all roles to that.
|
|
240
|
+
"""
|
|
241
|
+
return [msg.message for msg in self.history]
|
|
242
|
+
|
|
243
|
+
def reset(self) -> None:
|
|
244
|
+
"""Clear all messages from the conversation"""
|
|
245
|
+
self.history.clear()
|
|
246
|
+
self._total_tokens = 0
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from typing_extensions import override
|
|
4
|
+
|
|
5
|
+
from notte_agent.common.base import BaseAgent
|
|
6
|
+
from notte_agent.common.types import AgentResponse
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseNotifier(ABC):
    """Abstract notification channel for agent results.

    Subclasses only provide ``send_message``; ``notify`` builds the report text
    and forwards it.
    """

    @abstractmethod
    async def send_message(self, text: str) -> None:
        """Deliver ``text`` through the concrete notification service."""
        ...

    async def notify(self, task: str, result: AgentResponse) -> None:
        """Format a report about ``task`` and ``result`` and send it.

        Args:
            task: The task description
            result: The agent's response to be sent
        """
        duration_s = round(result.duration_in_s, 2)
        status = "✅ Success" if result.success else "❌ Failed"
        report = f"""
Notte Agent Report 🌙

Task Details:
-------------
Task: {task}
Execution Time: {duration_s} seconds
Status: {status}


Agent Response:
--------------
{result.answer}

Powered by Notte 🌒"""
        await self.send_message(text=report)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class NotifierAgent(BaseAgent):
    """Wrap another agent and report each run's result through a notifier."""

    def __init__(self, agent: BaseAgent, notifier: BaseNotifier):
        # Share the wrapped agent's session so `self.session` refers to the same object.
        super().__init__(session=agent.session)
        self.agent: BaseAgent = agent
        self.notifier: BaseNotifier = notifier

    @override
    async def run(self, task: str, url: str | None = None) -> AgentResponse:
        """Delegate to the wrapped agent, notify about the outcome, return it.

        If the wrapped agent raises, the exception propagates and no
        notification is sent.
        """
        response = await self.agent.run(task, url)
        await self.notifier.notify(task, response)
        return response
|