beeai-framework 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- beeai_framework/__init__.py +42 -0
- beeai_framework/adapters/__init__.py +1 -0
- beeai_framework/adapters/litellm/chat.py +158 -0
- beeai_framework/adapters/ollama/__init__.py +1 -0
- beeai_framework/adapters/ollama/backend/__init__.py +1 -0
- beeai_framework/adapters/ollama/backend/chat.py +19 -0
- beeai_framework/adapters/watsonx/__init__.py +1 -0
- beeai_framework/adapters/watsonx/backend/__init__.py +1 -0
- beeai_framework/adapters/watsonx/backend/chat.py +19 -0
- beeai_framework/agents/__init__.py +6 -0
- beeai_framework/agents/base.py +60 -0
- beeai_framework/agents/bee/__init__.py +5 -0
- beeai_framework/agents/bee/agent.py +165 -0
- beeai_framework/agents/errors.py +10 -0
- beeai_framework/agents/runners/base.py +130 -0
- beeai_framework/agents/runners/default/prompts.py +152 -0
- beeai_framework/agents/runners/default/runner.py +185 -0
- beeai_framework/agents/runners/granite/prompts.py +94 -0
- beeai_framework/agents/runners/granite/runner.py +67 -0
- beeai_framework/agents/types.py +89 -0
- beeai_framework/backend/__init__.py +25 -0
- beeai_framework/backend/chat.py +266 -0
- beeai_framework/backend/constants.py +28 -0
- beeai_framework/backend/errors.py +25 -0
- beeai_framework/backend/message.py +141 -0
- beeai_framework/backend/utils.py +66 -0
- beeai_framework/cancellation.py +76 -0
- beeai_framework/context.py +160 -0
- beeai_framework/emitter/__init__.py +27 -0
- beeai_framework/emitter/emitter.py +192 -0
- beeai_framework/emitter/errors.py +10 -0
- beeai_framework/emitter/types.py +24 -0
- beeai_framework/emitter/utils.py +17 -0
- beeai_framework/errors.py +179 -0
- beeai_framework/llms/__init__.py +27 -0
- beeai_framework/llms/base_output.py +27 -0
- beeai_framework/llms/llm.py +133 -0
- beeai_framework/llms/output.py +29 -0
- beeai_framework/memory/__init__.py +34 -0
- beeai_framework/memory/base_cache.py +109 -0
- beeai_framework/memory/base_memory.py +83 -0
- beeai_framework/memory/errors.py +32 -0
- beeai_framework/memory/file_cache.py +238 -0
- beeai_framework/memory/readonly_memory.py +34 -0
- beeai_framework/memory/serializable.py +97 -0
- beeai_framework/memory/serializer.py +254 -0
- beeai_framework/memory/sliding_cache.py +111 -0
- beeai_framework/memory/sliding_memory.py +129 -0
- beeai_framework/memory/summarize_memory.py +76 -0
- beeai_framework/memory/task_map.py +144 -0
- beeai_framework/memory/token_memory.py +123 -0
- beeai_framework/memory/unconstrained_cache.py +161 -0
- beeai_framework/memory/unconstrained_memory.py +37 -0
- beeai_framework/parsers/line_prefix.py +55 -0
- beeai_framework/tools/__init__.py +18 -0
- beeai_framework/tools/errors.py +13 -0
- beeai_framework/tools/mcp_tools.py +80 -0
- beeai_framework/tools/search/__init__.py +20 -0
- beeai_framework/tools/search/base.py +42 -0
- beeai_framework/tools/search/duckduckgo.py +56 -0
- beeai_framework/tools/search/wikipedia.py +15 -0
- beeai_framework/tools/tool.py +133 -0
- beeai_framework/tools/weather/openmeteo.py +124 -0
- beeai_framework/utils/__init__.py +8 -0
- beeai_framework/utils/_types.py +21 -0
- beeai_framework/utils/config.py +18 -0
- beeai_framework/utils/counter.py +37 -0
- beeai_framework/utils/custom_logger.py +107 -0
- beeai_framework/utils/errors.py +17 -0
- beeai_framework/utils/events.py +11 -0
- beeai_framework/utils/models.py +34 -0
- beeai_framework/utils/regex.py +11 -0
- beeai_framework/utils/templates.py +41 -0
- beeai_framework/workflows/__init__.py +18 -0
- beeai_framework/workflows/agent.py +110 -0
- beeai_framework/workflows/errors.py +8 -0
- beeai_framework/workflows/workflow.py +159 -0
- beeai_framework-0.1.0.dist-info/LICENSE +201 -0
- beeai_framework-0.1.0.dist-info/METADATA +172 -0
- beeai_framework-0.1.0.dist-info/RECORD +82 -0
- beeai_framework-0.1.0.dist-info/WHEEL +4 -0
- beeai_framework-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
from beeai_framework.agents import BaseAgent
|
|
3
|
+
from beeai_framework.agents.bee.agent import BeeAgent
|
|
4
|
+
from beeai_framework.backend import (
|
|
5
|
+
AssistantMessage,
|
|
6
|
+
CustomMessage,
|
|
7
|
+
Message,
|
|
8
|
+
Role,
|
|
9
|
+
SystemMessage,
|
|
10
|
+
ToolMessage,
|
|
11
|
+
UserMessage,
|
|
12
|
+
)
|
|
13
|
+
from beeai_framework.llms import LLM, AgentInput, BaseLLM
|
|
14
|
+
from beeai_framework.memory import BaseMemory, ReadOnlyMemory, TokenMemory, UnconstrainedMemory
|
|
15
|
+
from beeai_framework.memory.serializable import Serializable
|
|
16
|
+
from beeai_framework.tools import Tool, tool
|
|
17
|
+
from beeai_framework.tools.weather.openmeteo import OpenMeteoTool
|
|
18
|
+
from beeai_framework.utils.templates import Prompt
|
|
19
|
+
|
|
20
|
+
# Public API of the package root: every name below is imported at the top of
# this module and re-exported for `from beeai_framework import ...`.
__all__ = [
    "LLM",
    "AgentInput",
    "AssistantMessage",
    "BaseAgent",
    "BaseLLM",
    "BaseMemory",
    "BeeAgent",
    "CustomMessage",
    "Message",
    "OpenMeteoTool",
    "Prompt",
    "ReadOnlyMemory",
    "Role",
    "Serializable",
    "SystemMessage",
    "TokenMemory",
    "Tool",
    "ToolMessage",
    "UnconstrainedMemory",
    "UserMessage",
    "tool",
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import litellm
|
|
8
|
+
from litellm import (
|
|
9
|
+
ModelResponse,
|
|
10
|
+
ModelResponseStream,
|
|
11
|
+
acompletion,
|
|
12
|
+
get_supported_openai_params,
|
|
13
|
+
)
|
|
14
|
+
from pydantic import BaseModel, ConfigDict
|
|
15
|
+
|
|
16
|
+
from beeai_framework.backend.chat import (
|
|
17
|
+
ChatModel,
|
|
18
|
+
ChatModelInput,
|
|
19
|
+
ChatModelOutput,
|
|
20
|
+
ChatModelStructureInput,
|
|
21
|
+
ChatModelStructureOutput,
|
|
22
|
+
)
|
|
23
|
+
from beeai_framework.backend.errors import ChatModelError
|
|
24
|
+
from beeai_framework.backend.message import AssistantMessage, Message, Role, ToolMessage
|
|
25
|
+
from beeai_framework.backend.utils import parse_broken_json
|
|
26
|
+
from beeai_framework.context import RunContext
|
|
27
|
+
from beeai_framework.tools.tool import Tool
|
|
28
|
+
from beeai_framework.utils.custom_logger import BeeLogger
|
|
29
|
+
|
|
30
|
+
# Module-level logger, named after this module via __name__.
logger = BeeLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LiteLLMParameters(BaseModel):
    """Keyword arguments handed to ``litellm.acompletion`` (via ``model_dump()``).

    Extra keys are allowed so that adapter settings splatted into the
    constructor are forwarded to LiteLLM unchanged; arbitrary types are
    allowed because ``response_format`` may be a pydantic model class.
    """

    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    # LiteLLM model name, e.g. "<provider>/<model_id>".
    model: str
    # Conversation history as plain dicts.
    messages: list[dict[str, Any]]
    # Function-tool definitions, or None when no tools are offered.
    tools: list[dict[str, Any]] | None = None
    # Structured-output request: a schema dict or a pydantic model class.
    response_format: dict[str, Any] | type[BaseModel] | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class LiteLLMChatModel(ChatModel):
    """Chat model adapter that routes all requests through LiteLLM.

    Concrete subclasses set ``self._model_id``, ``self.settings`` and a
    ``provider_id`` (for example as a class attribute) *before* calling
    ``super().__init__()``. This constructor does not read its own
    ``model_id``/``settings`` arguments.
    """

    @property
    def model_id(self) -> str:
        """Model identifier, as set by the concrete subclass."""
        return self._model_id

    @property
    def provider_id(self) -> str:
        """Provider identifier, as set by the concrete subclass."""
        return self._provider_id

    def __init__(self, model_id: str | None = None, **settings: Any) -> None:
        # NOTE: `model_id` and `settings` are accepted for signature compatibility
        # only; subclasses populate `_model_id` / `settings` before calling us.
        self.supported_params = get_supported_openai_params(
            model=self.model_id, custom_llm_provider=self._litellm_provider()
        )
        # drop any unsupported parameters that were passed in
        # (this assigns a module-level LiteLLM setting, so it affects the whole process)
        litellm.drop_params = True
        super().__init__()

    def _litellm_provider(self) -> str:
        """Map the framework provider id onto LiteLLM's provider name.

        Ollama is routed through LiteLLM's ``ollama_chat`` provider.
        """
        return "ollama_chat" if self.provider_id == "ollama" else self.provider_id

    @staticmethod
    def _find_tool(tools: list[Tool] | None, name: str) -> Tool:
        """Return the tool called ``name`` from ``tools``.

        Raises:
            ChatModelError: if no such tool was provided. (A bare ``next()`` would
                raise StopIteration, which a coroutine turns into an opaque
                RuntimeError.)
        """
        found = next((tool for tool in tools or [] if tool.name == name), None)
        if found is None:
            raise ChatModelError(f"Model requested unknown tool '{name}'.")
        return found

    @staticmethod
    def _tool_output_to_text(output: Any) -> str:
        """Render a tool result as message text for the follow-up completion.

        Objects exposing ``get_text_content()`` (presumably ``ToolOutput``) use it;
        anything else is converted with ``str()``.
        """
        if isinstance(output, str):
            return output
        get_text_content = getattr(output, "get_text_content", None)
        return get_text_content() if callable(get_text_content) else str(output)

    async def _create(
        self,
        input: ChatModelInput,
        run: RunContext,
    ) -> ChatModelOutput:
        """Run a non-streaming completion, resolving one round of tool calls.

        If the model requests tool calls, each requested tool is run, its result
        is appended to the conversation, and a second completion is issued.
        Only a single round of tool calls is handled.

        Raises:
            ChatModelError: if the model requests a tool that was not provided.
        """
        litellm_input = self._transform_input(input)
        response = await acompletion(**litellm_input.model_dump())
        response_message = response.get("choices", [{}])[0].get("message", {})
        # getattr: the `{}` fallback above has no `tool_calls` attribute.
        tool_calls = getattr(response_message, "tool_calls", None)

        if tool_calls:
            # OpenAI-style APIs require the assistant turn to carry its tool calls,
            # and every tool message to reference the id of the call it answers.
            litellm_input.messages.append(
                {
                    "role": Role.ASSISTANT,
                    "content": response_message.get("content", ""),
                    "tool_calls": [
                        {
                            "id": tool_call.id,
                            "type": "function",
                            "function": {
                                "name": tool_call.function.name,
                                "arguments": tool_call.function.arguments,
                            },
                        }
                        for tool_call in tool_calls
                    ],
                }
            )
            for tool_call in tool_calls:
                function_to_call = self._find_tool(input.tools, tool_call.function.name)
                function_args = json.loads(tool_call.function.arguments)
                function_response = function_to_call.run(input=function_args)
                litellm_input.messages.append(
                    {
                        "role": Role.TOOL,
                        "tool_call_id": tool_call.id,
                        "content": self._tool_output_to_text(function_response),
                    }
                )

            response = await acompletion(**litellm_input.model_dump())

        response_output = self._transform_output(response)
        logger.trace(f"Inference response output:\n{response_output}")
        return response_output

    async def _create_stream(self, input: ChatModelInput, _: RunContext) -> AsyncGenerator[ChatModelOutput]:
        """Stream the completion as a sequence of ``ChatModelOutput`` chunks.

        Chunks for which ``_transform_output`` returns nothing (stream chunks
        carrying a finish reason) are skipped. Tool calling is not handled here.
        """
        litellm_input = self._transform_input(input)
        parameters = litellm_input.model_dump()
        parameters["stream"] = True
        response = await acompletion(**parameters)

        # TODO: handle tool calling for streaming
        async for chunk in response:
            response_output = self._transform_output(chunk)
            if not response_output:
                continue
            yield response_output

    async def _create_structure(self, input: ChatModelStructureInput, run: RunContext) -> ChatModelStructureOutput:
        """Generate structured output.

        Uses the provider's native ``response_format`` when LiteLLM lists it as a
        supported parameter; otherwise defers to the base ``ChatModel``
        implementation. The parsed object is not yet validated against the schema.
        """
        # get_supported_openai_params may return None for unknown providers.
        if "response_format" not in (self.supported_params or []):
            logger.warning(f"{self.provider_id} model {self.model_id} does not support structured data.")
            return await super()._create_structure(input, run)

        response = await self._create(
            ChatModelInput(messages=input.messages, response_format=input.schema, abort_signal=input.abort_signal),
            run,
        )

        logger.trace(f"Structured response received:\n{response}")

        text_response = response.get_text_content()
        result = parse_broken_json(text_response)
        # TODO: validate result matches expected schema
        return ChatModelStructureOutput(object=result)

    def _get_model_name(self) -> str:
        """Return the LiteLLM model string ``"<litellm provider>/<model id>"``."""
        return f"{self._litellm_provider()}/{self.model_id}"

    def _transform_input(self, input: ChatModelInput) -> LiteLLMParameters:
        """Convert a framework ``ChatModelInput`` into LiteLLM call parameters.

        ``self.settings`` is splatted in last, so adapter settings become extra
        keyword arguments for ``acompletion``.
        """
        messages_list = [message.to_plain() for message in input.messages]

        if input.tools:
            prepared_tools_list = [{"type": "function", "function": tool.prompt_data()} for tool in input.tools]
        else:
            prepared_tools_list = None

        model = self._get_model_name()

        return LiteLLMParameters(
            model=model,
            messages=messages_list,
            tools=prepared_tools_list,
            response_format=input.response_format,
            **self.settings,
        )

    def _transform_output(self, chunk: ModelResponse | ModelResponseStream) -> ChatModelOutput | None:
        """Convert a LiteLLM response or stream chunk into a ``ChatModelOutput``.

        Returns:
            ``None`` for stream chunks that carry a finish reason; otherwise the
            converted output.

        Raises:
            ChatModelError: for stream chunks with neither tool calls nor a delta.
        """
        choice = chunk.get("choices", [{}])[0]
        finish_reason = choice.get("finish_reason")
        message: Message | None = None
        # NOTE(review): LiteLLM appears to report token usage on the response
        # (`chunk.usage`) rather than on the choice, so this is likely always None.
        # Forwarding `chunk.usage` requires knowing ChatModelOutput's usage type — TODO confirm.
        usage = choice.get("usage")

        if isinstance(chunk, ModelResponseStream):
            if finish_reason:
                return None
            content = choice.get("delta", {}).get("content")
            if choice.get("tool_calls"):
                message = ToolMessage(content)
            elif choice.get("delta"):
                message = AssistantMessage(content)
            else:
                # TODO: handle other possible types
                raise ChatModelError(f"Unhandled event: {choice}")
        else:
            response_message = choice.get("message")
            content = response_message.get("content")
            message = AssistantMessage(content)

        return ChatModelOutput(messages=[message], finish_reason=finish_reason, usage=usage)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from beeai_framework.adapters.litellm.chat import LiteLLMChatModel
|
|
7
|
+
from beeai_framework.backend.constants import ProviderName
|
|
8
|
+
from beeai_framework.utils.custom_logger import BeeLogger
|
|
9
|
+
|
|
10
|
+
# Module-level logger, named after this module via __name__.
logger = BeeLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OllamaChatModel(LiteLLMChatModel):
    """LiteLLM-backed chat model for an Ollama server.

    The model id falls back to the ``OLLAMA_CHAT_MODEL`` environment variable
    and then to ``"llama3.1:8b"``. ``base_url`` defaults to
    ``http://localhost:11434``; any keyword in ``settings`` overrides the default.
    """

    provider_id: ProviderName = "ollama"

    def __init__(self, model_id: str | None = None, **settings: Any) -> None:
        defaults = {"base_url": "http://localhost:11434"}
        # Caller-supplied settings take precedence over the defaults.
        self.settings = {**defaults, **settings}
        self._model_id = model_id or os.getenv("OLLAMA_CHAT_MODEL", "llama3.1:8b")
        super().__init__()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from beeai_framework.adapters.litellm.chat import LiteLLMChatModel
|
|
7
|
+
from beeai_framework.backend.constants import ProviderName
|
|
8
|
+
from beeai_framework.utils.custom_logger import BeeLogger
|
|
9
|
+
|
|
10
|
+
# Module-level logger, named after this module via __name__.
logger = BeeLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class WatsonxChatModel(LiteLLMChatModel):
    """LiteLLM-backed chat model for IBM watsonx.

    The model id falls back to the ``WATSONX_CHAT_MODEL`` environment variable
    and then to ``"ibm/granite-3-8b-instruct"``. ``settings`` are stored as-is.
    """

    provider_id: ProviderName = "watsonx"

    def __init__(self, model_id: str | None = None, **settings: Any) -> None:
        self.settings = settings
        self._model_id = model_id or os.getenv("WATSONX_CHAT_MODEL", "ibm/granite-3-8b-instruct")
        super().__init__()
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
from beeai_framework.agents.types import AgentMeta, BeeRunInput, BeeRunOptions, BeeRunOutput
|
|
6
|
+
from beeai_framework.context import Run, RunContext, RunContextInput, RunInstance
|
|
7
|
+
from beeai_framework.emitter import Emitter
|
|
8
|
+
from beeai_framework.memory import BaseMemory
|
|
9
|
+
from beeai_framework.utils.models import ModelLike, to_model, to_model_optional
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseAgent(ABC):
    """Abstract base class for agents.

    Subclasses implement ``_run`` (the agent loop) and the ``memory`` property.
    ``run`` validates its inputs, refuses to start while a previous run of the
    same agent is still executing, and wraps ``_run`` in a ``RunContext``.
    """

    # True while this agent's `_run` coroutine is executing.
    is_running: bool = False
    # Event emitter; concrete agents are expected to assign one in __init__.
    emitter: Emitter | None = None

    def run(self, run_input: ModelLike[BeeRunInput], options: ModelLike[BeeRunOptions] | None = None) -> Run:
        """Start a run of this agent.

        Args:
            run_input: The run input (a model or anything ``to_model`` accepts).
            options: Optional run options.

        Returns:
            The ``Run`` created by ``RunContext.enter``.

        Raises:
            RuntimeError: if the agent is already running, or (chained from the
                original exception) if creating the run context fails.
        """
        run_input = to_model(BeeRunInput, run_input)
        options = to_model_optional(BeeRunOptions, options)

        if self.is_running:
            raise RuntimeError("Agent is already running!")

        async def guarded_run(context: RunContext) -> BeeRunOutput:
            # The flag is owned by the coroutine itself: it is raised when
            # execution actually starts and cleared when `_run` returns, raises
            # or is cancelled. (Previously it was never set to True at all.)
            self.is_running = True
            try:
                return await self._run(run_input, options, context)
            finally:
                self.is_running = False

        try:
            return RunContext.enter(
                RunInstance(emitter=self.emitter),
                RunContextInput(signal=options.signal if options else None, params=(run_input, options)),
                guarded_run,
            )
        except RuntimeError:
            raise
        except Exception as e:
            raise RuntimeError("Error has occurred!") from e

    @abstractmethod
    async def _run(self, run_input: BeeRunInput, options: BeeRunOptions | None, context: RunContext) -> BeeRunOutput:
        """Execute the agent for one run; implemented by subclasses."""
        pass

    def destroy(self) -> None:
        """Destroy this agent's emitter, if one has been assigned."""
        if self.emitter is not None:
            self.emitter.destroy()

    @property
    @abstractmethod
    def memory(self) -> BaseMemory:
        """The agent's memory."""
        pass

    @memory.setter
    @abstractmethod
    def memory(self, memory: BaseMemory) -> None:
        pass

    @property
    def meta(self) -> AgentMeta:
        """Default metadata: the class name, an empty description and no tools."""
        return AgentMeta(
            name=self.__class__.__name__,
            description="",
            tools=[],
        )
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from datetime import UTC, datetime
|
|
5
|
+
|
|
6
|
+
from beeai_framework.agents.base import BaseAgent
|
|
7
|
+
from beeai_framework.agents.runners.base import (
|
|
8
|
+
BaseRunner,
|
|
9
|
+
BeeRunnerToolInput,
|
|
10
|
+
BeeRunnerToolResult,
|
|
11
|
+
RunnerIteration,
|
|
12
|
+
)
|
|
13
|
+
from beeai_framework.agents.runners.default.runner import DefaultRunner
|
|
14
|
+
from beeai_framework.agents.runners.granite.runner import GraniteRunner
|
|
15
|
+
from beeai_framework.agents.types import (
|
|
16
|
+
AgentMeta,
|
|
17
|
+
BeeAgentExecutionConfig,
|
|
18
|
+
BeeInput,
|
|
19
|
+
BeeRunInput,
|
|
20
|
+
BeeRunOptions,
|
|
21
|
+
BeeRunOutput,
|
|
22
|
+
)
|
|
23
|
+
from beeai_framework.backend import Message
|
|
24
|
+
from beeai_framework.backend.message import AssistantMessage, MessageMeta, UserMessage
|
|
25
|
+
from beeai_framework.context import RunContext
|
|
26
|
+
from beeai_framework.emitter import Emitter, EmitterInput
|
|
27
|
+
from beeai_framework.memory import BaseMemory
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class BeeAgent(BaseAgent):
    """Agent that alternates LLM iterations with tool calls until a final answer.

    The runner class is chosen from the LLM's model id: ids containing
    ``"granite"`` use ``GraniteRunner``; all others use ``DefaultRunner``.
    """

    # Runner class instantiated once per `_run` call.
    runner: Callable[..., BaseRunner]

    def __init__(self, bee_input: BeeInput) -> None:
        self.input = bee_input
        if "granite" in self.input.llm.model_id:
            self.runner = GraniteRunner
        else:
            self.runner = DefaultRunner
        self.emitter = Emitter.root().child(
            EmitterInput(
                namespace=["agent", "bee"],
                creator=self,
            )
        )

    @property
    def memory(self) -> BaseMemory:
        """The agent's long-lived memory, stored on ``self.input``."""
        return self.input.memory

    @memory.setter
    def memory(self, memory: BaseMemory) -> None:
        self.input.memory = memory

    @property
    def meta(self) -> AgentMeta:
        """Describe the agent and its tools.

        A user-supplied ``input.meta`` takes precedence. Otherwise a default
        "BeeAI" description is returned, with one line per tool when any exist.
        """
        tools = self.input.tools[:]

        if self.input.meta:
            return AgentMeta(
                name=self.input.meta.name,
                description=self.input.meta.description,
                extra_description=self.input.meta.extra_description,
                tools=tools,
            )

        extra_description = ["Tools that I can use to accomplish given task."]
        for tool in tools:
            # Python f-string placeholders (no `$`), e.g. "Tool 'search': Search the web."
            extra_description.append(f"Tool '{tool.name}': {tool.description}.")

        return AgentMeta(
            name="BeeAI",
            tools=tools,
            description="The BeeAI framework demonstrates its ability to auto-correct and adapt in real-time, improving"
            " the overall reliability and resilience of the system.",
            extra_description="\n".join(extra_description) if len(tools) > 0 else None,
        )

    def _resolve_options(self, options: BeeRunOptions | None) -> BeeRunOptions:
        """Return run options whose execution config is always filled in.

        An explicit ``options.execution`` is kept as-is. Otherwise
        ``input.execution`` is used, falling back to built-in defaults. Without
        this, options passed without an execution config reach the runner with
        ``max_iterations`` treated as 0 and fail on the very first iteration.
        """
        if options is not None and options.execution is not None:
            return options

        execution = self.input.execution or BeeAgentExecutionConfig(
            max_retries_per_step=3,
            total_max_retries=20,
            max_iterations=10,
        )
        # NOTE(review): only `execution` and `signal` are carried over; extend this
        # if BeeRunOptions gains further fields.
        return BeeRunOptions(
            execution=execution,
            signal=options.signal if options is not None else None,
        )

    async def _run(self, run_input: BeeRunInput, options: BeeRunOptions | None, context: RunContext) -> BeeRunOutput:
        """Drive the runner until an iteration produces a final answer.

        Each iteration may execute a tool; its rendered outcome is added to the
        runner's memory and announced via "partialUpdate"/"update" events. Once a
        final answer appears, a "success" event is emitted, and the prompt (if
        any) plus the final answer are stored in the agent's own memory.
        """
        runner = self.runner(self.input, self._resolve_options(options), context)
        await runner.init(run_input)

        final_message: Message | None = None
        while not final_message:
            iteration: RunnerIteration = await runner.create_iteration()

            if iteration.state.tool_name and iteration.state.tool_input:
                tool_result: BeeRunnerToolResult = await runner.tool(
                    input=BeeRunnerToolInput(
                        state=iteration.state,
                        emitter=iteration.emitter,
                        meta=iteration.meta,
                        signal=iteration.signal,
                    )
                )
                await runner.memory.add(
                    AssistantMessage(
                        content=runner.templates.assistant.render(
                            {
                                "thought": iteration.state.thought,
                                "tool_name": iteration.state.tool_name,
                                "tool_input": iteration.state.tool_input,
                                "tool_output": tool_result.output.to_string(),
                                "final_answer": iteration.state.final_answer,
                            }
                        ),
                        meta=MessageMeta({"success": tool_result.success}),
                    )
                )
                iteration.state.tool_output = tool_result.output.get_text_content()

                for key in ["partialUpdate", "update"]:
                    await iteration.emitter.emit(
                        key,
                        {
                            "data": iteration.state,
                            "update": {
                                "key": "tool_output",
                                "value": tool_result.output,
                                "parsedValue": tool_result.output.to_string(),
                            },
                            "meta": {"success": tool_result.success},  # TODO deleted meta
                            "memory": runner.memory,
                        },
                    )

            if iteration.state.final_answer:
                final_message = AssistantMessage(
                    content=iteration.state.final_answer, meta=MessageMeta({"createdAt": datetime.now(tz=UTC)})
                )
                await runner.memory.add(final_message)
                await iteration.emitter.emit(
                    "success",
                    {
                        "data": final_message,
                        "iterations": runner.iterations,
                        "memory": runner.memory,
                        "meta": iteration.meta,
                    },
                )

        if run_input.prompt is not None:
            await self.input.memory.add(
                UserMessage(content=run_input.prompt, meta=MessageMeta({"createdAt": context.created_at}))
            )

        await self.input.memory.add(final_message)

        return BeeRunOutput(result=final_message, iterations=runner.iterations, memory=runner.memory)
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
from beeai_framework.errors import FrameworkError
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class AgentError(FrameworkError):
    """Raised for errors caused by agents.

    Instances are always flagged as fatal and not retryable; an optional
    ``cause`` records the underlying exception.
    """

    def __init__(self, message: str = "Agent error", *, cause: Exception | None = None) -> None:
        fatal, retryable = True, False
        super().__init__(message, is_fatal=fatal, is_retryable=retryable, cause=cause)
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
from beeai_framework.agents.types import (
|
|
7
|
+
BeeAgentRunIteration,
|
|
8
|
+
BeeAgentTemplates,
|
|
9
|
+
BeeInput,
|
|
10
|
+
BeeIterationResult,
|
|
11
|
+
BeeMeta,
|
|
12
|
+
BeeRunInput,
|
|
13
|
+
BeeRunOptions,
|
|
14
|
+
)
|
|
15
|
+
from beeai_framework.cancellation import AbortSignal
|
|
16
|
+
from beeai_framework.context import RunContext
|
|
17
|
+
from beeai_framework.emitter.emitter import Emitter
|
|
18
|
+
from beeai_framework.emitter.types import EmitterInput
|
|
19
|
+
from beeai_framework.memory.base_memory import BaseMemory
|
|
20
|
+
from beeai_framework.tools import ToolOutput
|
|
21
|
+
from beeai_framework.utils.counter import RetryCounter
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class BeeRunnerLLMInput:
    """Arguments for ``BaseRunner.llm`` (built once per iteration by ``create_iteration``)."""

    meta: BeeMeta  # iteration metadata (1-based iteration number)
    signal: AbortSignal  # abort signal of the surrounding run
    emitter: Emitter  # child emitter scoped to this iteration
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class RunnerIteration:
    """Result of ``BaseRunner.create_iteration``: one step of the agent loop."""

    emitter: Emitter  # child emitter scoped to this iteration
    state: BeeIterationResult  # parsed LLM output for this iteration
    meta: BeeMeta  # iteration metadata
    signal: AbortSignal  # abort signal of the surrounding run
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class BeeRunnerToolResult:
    """Outcome of ``BaseRunner.tool``: the tool's output and whether it succeeded."""

    output: ToolOutput  # what the tool produced
    success: bool  # whether the tool call succeeded
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class BeeRunnerToolInput:
    """Arguments for ``BaseRunner.tool`` describing the tool call to perform."""

    state: BeeIterationResult  # TODO BeeIterationToolResult
    meta: BeeMeta  # iteration metadata
    signal: AbortSignal  # abort signal of the surrounding run
    emitter: Emitter  # child emitter scoped to the current iteration
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class BaseRunner(ABC):
    """Base class for Bee agent runners.

    A runner holds the per-run state of an agent execution: the iteration
    history, the working memory and retry counters. Subclasses supply the LLM
    step (``llm``), tool execution (``tool``), prompt templates
    (``default_templates``) and memory setup (``init_memory``).
    """

    def __init__(self, input: BeeInput, options: BeeRunOptions, run: RunContext) -> None:
        self._input = input
        self._options = options
        # NOTE(review): this counter is sized from `max_iterations`, while
        # `_failedAttemptsCounter` below uses `total_max_retries`. Both are kept
        # because subclasses (not visible here) may reference either name.
        self._failed_attempts_counter = RetryCounter(
            max_retries=(
                options.execution.max_iterations if options.execution and options.execution.max_iterations else 0
            ),
            error_type=Exception,  # TODO Specific error type
        )

        # Populated by `init()`; `memory` raises until then.
        self._memory: BaseMemory | None = None

        self._iterations: list[BeeAgentRunIteration] = []
        self._failedAttemptsCounter: RetryCounter = RetryCounter(
            error_type=Exception,  # TODO AgentError
            max_retries=(
                options.execution.total_max_retries if options.execution and options.execution.total_max_retries else 0
            ),
        )
        # RunContext of the current agent run (an attribute, not BaseAgent._run).
        self._run = run

    @property
    def iterations(self) -> list[BeeAgentRunIteration]:
        """Iterations completed so far, in order."""
        return self._iterations

    @property
    def memory(self) -> BaseMemory:
        """Working memory created by ``init()``.

        Raises:
            Exception: if ``init()`` has not been awaited yet.
        """
        if self._memory is not None:
            return self._memory
        raise Exception("Memory has not been initialized.")

    async def create_iteration(self) -> RunnerIteration:
        """Run the next LLM iteration and record it.

        Raises:
            Exception: once the iteration number exceeds ``max_iterations``
                (treated as 0 when no execution config is set).
        """
        meta: BeeMeta = BeeMeta(iteration=len(self._iterations) + 1)
        max_iterations = (
            self._options.execution.max_iterations
            if self._options.execution and self._options.execution.max_iterations
            else 0
        )

        if meta.iteration > max_iterations:
            # TODO: Raise Agent Error with metadata
            # https://github.com/i-am-bee/beeai-framework/blob/aa4d5e6091ed3bab8096492707ceb03d3b03863b/src/agents/bee/runners/base.ts#L70
            raise Exception(f"Agent was not able to resolve the task in {max_iterations} iterations.")

        # Group id such as "iteration-1" (plain f-string; no JS `${...}` syntax).
        emitter = self._run.emitter.child(emitter_input=EmitterInput(group_id=f"iteration-{meta.iteration}"))
        iteration: BeeAgentRunIteration = await self.llm(
            BeeRunnerLLMInput(emitter=emitter, signal=self._run.signal, meta=meta)
        )
        self._iterations.append(iteration)

        return RunnerIteration(emitter=emitter, state=iteration.state, meta=meta, signal=self._run.signal)

    async def init(self, input: BeeRunInput) -> None:
        """Prepare the runner for a run by building its working memory."""
        self._memory = await self.init_memory(input)

    @abstractmethod
    async def llm(self, input: BeeRunnerLLMInput) -> BeeAgentRunIteration:
        """Query the LLM for one iteration and return its parsed result."""
        pass

    @abstractmethod
    async def tool(self, input: BeeRunnerToolInput) -> BeeRunnerToolResult:
        """Execute the tool requested in ``input.state``."""
        pass

    @abstractmethod
    def default_templates(self) -> BeeAgentTemplates:
        """Return the prompt templates used by this runner."""
        pass

    @abstractmethod
    async def init_memory(self, input: BeeRunInput) -> BaseMemory:
        """Build the working memory for a run."""
        pass

    @property
    def templates(self) -> BeeAgentTemplates:
        """Prompt templates in effect (currently always ``default_templates()``)."""
        # TODO: overrides
        return self.default_templates()

    # TODO: Serialization
|