beeai_framework-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. beeai_framework/__init__.py +42 -0
  2. beeai_framework/adapters/__init__.py +1 -0
  3. beeai_framework/adapters/litellm/chat.py +158 -0
  4. beeai_framework/adapters/ollama/__init__.py +1 -0
  5. beeai_framework/adapters/ollama/backend/__init__.py +1 -0
  6. beeai_framework/adapters/ollama/backend/chat.py +19 -0
  7. beeai_framework/adapters/watsonx/__init__.py +1 -0
  8. beeai_framework/adapters/watsonx/backend/__init__.py +1 -0
  9. beeai_framework/adapters/watsonx/backend/chat.py +19 -0
  10. beeai_framework/agents/__init__.py +6 -0
  11. beeai_framework/agents/base.py +60 -0
  12. beeai_framework/agents/bee/__init__.py +5 -0
  13. beeai_framework/agents/bee/agent.py +165 -0
  14. beeai_framework/agents/errors.py +10 -0
  15. beeai_framework/agents/runners/base.py +130 -0
  16. beeai_framework/agents/runners/default/prompts.py +152 -0
  17. beeai_framework/agents/runners/default/runner.py +185 -0
  18. beeai_framework/agents/runners/granite/prompts.py +94 -0
  19. beeai_framework/agents/runners/granite/runner.py +67 -0
  20. beeai_framework/agents/types.py +89 -0
  21. beeai_framework/backend/__init__.py +25 -0
  22. beeai_framework/backend/chat.py +266 -0
  23. beeai_framework/backend/constants.py +28 -0
  24. beeai_framework/backend/errors.py +25 -0
  25. beeai_framework/backend/message.py +141 -0
  26. beeai_framework/backend/utils.py +66 -0
  27. beeai_framework/cancellation.py +76 -0
  28. beeai_framework/context.py +160 -0
  29. beeai_framework/emitter/__init__.py +27 -0
  30. beeai_framework/emitter/emitter.py +192 -0
  31. beeai_framework/emitter/errors.py +10 -0
  32. beeai_framework/emitter/types.py +24 -0
  33. beeai_framework/emitter/utils.py +17 -0
  34. beeai_framework/errors.py +179 -0
  35. beeai_framework/llms/__init__.py +27 -0
  36. beeai_framework/llms/base_output.py +27 -0
  37. beeai_framework/llms/llm.py +133 -0
  38. beeai_framework/llms/output.py +29 -0
  39. beeai_framework/memory/__init__.py +34 -0
  40. beeai_framework/memory/base_cache.py +109 -0
  41. beeai_framework/memory/base_memory.py +83 -0
  42. beeai_framework/memory/errors.py +32 -0
  43. beeai_framework/memory/file_cache.py +238 -0
  44. beeai_framework/memory/readonly_memory.py +34 -0
  45. beeai_framework/memory/serializable.py +97 -0
  46. beeai_framework/memory/serializer.py +254 -0
  47. beeai_framework/memory/sliding_cache.py +111 -0
  48. beeai_framework/memory/sliding_memory.py +129 -0
  49. beeai_framework/memory/summarize_memory.py +76 -0
  50. beeai_framework/memory/task_map.py +144 -0
  51. beeai_framework/memory/token_memory.py +123 -0
  52. beeai_framework/memory/unconstrained_cache.py +161 -0
  53. beeai_framework/memory/unconstrained_memory.py +37 -0
  54. beeai_framework/parsers/line_prefix.py +55 -0
  55. beeai_framework/tools/__init__.py +18 -0
  56. beeai_framework/tools/errors.py +13 -0
  57. beeai_framework/tools/mcp_tools.py +80 -0
  58. beeai_framework/tools/search/__init__.py +20 -0
  59. beeai_framework/tools/search/base.py +42 -0
  60. beeai_framework/tools/search/duckduckgo.py +56 -0
  61. beeai_framework/tools/search/wikipedia.py +15 -0
  62. beeai_framework/tools/tool.py +133 -0
  63. beeai_framework/tools/weather/openmeteo.py +124 -0
  64. beeai_framework/utils/__init__.py +8 -0
  65. beeai_framework/utils/_types.py +21 -0
  66. beeai_framework/utils/config.py +18 -0
  67. beeai_framework/utils/counter.py +37 -0
  68. beeai_framework/utils/custom_logger.py +107 -0
  69. beeai_framework/utils/errors.py +17 -0
  70. beeai_framework/utils/events.py +11 -0
  71. beeai_framework/utils/models.py +34 -0
  72. beeai_framework/utils/regex.py +11 -0
  73. beeai_framework/utils/templates.py +41 -0
  74. beeai_framework/workflows/__init__.py +18 -0
  75. beeai_framework/workflows/agent.py +110 -0
  76. beeai_framework/workflows/errors.py +8 -0
  77. beeai_framework/workflows/workflow.py +159 -0
  78. beeai_framework-0.1.0.dist-info/LICENSE +201 -0
  79. beeai_framework-0.1.0.dist-info/METADATA +172 -0
  80. beeai_framework-0.1.0.dist-info/RECORD +82 -0
  81. beeai_framework-0.1.0.dist-info/WHEEL +4 -0
  82. beeai_framework-0.1.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,42 @@
+ # SPDX-License-Identifier: Apache-2.0
+ from beeai_framework.agents import BaseAgent
+ from beeai_framework.agents.bee.agent import BeeAgent
+ from beeai_framework.backend import (
+     AssistantMessage,
+     CustomMessage,
+     Message,
+     Role,
+     SystemMessage,
+     ToolMessage,
+     UserMessage,
+ )
+ from beeai_framework.llms import LLM, AgentInput, BaseLLM
+ from beeai_framework.memory import BaseMemory, ReadOnlyMemory, TokenMemory, UnconstrainedMemory
+ from beeai_framework.memory.serializable import Serializable
+ from beeai_framework.tools import Tool, tool
+ from beeai_framework.tools.weather.openmeteo import OpenMeteoTool
+ from beeai_framework.utils.templates import Prompt
+
+ __all__ = [
+     "LLM",
+     "AgentInput",
+     "AssistantMessage",
+     "BaseAgent",
+     "BaseLLM",
+     "BaseMemory",
+     "BeeAgent",
+     "CustomMessage",
+     "Message",
+     "OpenMeteoTool",
+     "Prompt",
+     "ReadOnlyMemory",
+     "Role",
+     "Serializable",
+     "SystemMessage",
+     "TokenMemory",
+     "Tool",
+     "ToolMessage",
+     "UnconstrainedMemory",
+     "UserMessage",
+     "tool",
+ ]
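The top-level `__init__.py` above re-exports the framework's public entry points, so downstream code imports them from the package root. A minimal sketch of what that surface allows, restricted to names listed in `__all__` (how these objects are composed is shown in later files, not implied here):

# Hedged sketch: imports limited to names re-exported in __all__ above.
from beeai_framework import BeeAgent, OpenMeteoTool, TokenMemory, UnconstrainedMemory, tool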
@@ -0,0 +1 @@
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,158 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ import json
+ from collections.abc import AsyncGenerator
+ from typing import Any
+
+ import litellm
+ from litellm import (
+     ModelResponse,
+     ModelResponseStream,
+     acompletion,
+     get_supported_openai_params,
+ )
+ from pydantic import BaseModel, ConfigDict
+
+ from beeai_framework.backend.chat import (
+     ChatModel,
+     ChatModelInput,
+     ChatModelOutput,
+     ChatModelStructureInput,
+     ChatModelStructureOutput,
+ )
+ from beeai_framework.backend.errors import ChatModelError
+ from beeai_framework.backend.message import AssistantMessage, Message, Role, ToolMessage
+ from beeai_framework.backend.utils import parse_broken_json
+ from beeai_framework.context import RunContext
+ from beeai_framework.tools.tool import Tool
+ from beeai_framework.utils.custom_logger import BeeLogger
+
+ logger = BeeLogger(__name__)
+
+
+ class LiteLLMParameters(BaseModel):
+     model: str
+     messages: list[dict[str, Any]]
+     tools: list[dict[str, Any]] | None = None
+     response_format: dict[str, Any] | type[BaseModel] | None = None
+
+     model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
+
+
+ class LiteLLMChatModel(ChatModel):
+     @property
+     def model_id(self) -> str:
+         return self._model_id
+
+     @property
+     def provider_id(self) -> str:
+         return self._provider_id
+
+     def __init__(self, model_id: str | None = None, **settings: Any) -> None:
+         llm_provider = "ollama_chat" if self.provider_id == "ollama" else self.provider_id
+         self.supported_params = get_supported_openai_params(model=self.model_id, custom_llm_provider=llm_provider)
+         # drop any unsupported parameters that were passed in
+         litellm.drop_params = True
+         super().__init__()
+
+     async def _create(
+         self,
+         input: ChatModelInput,
+         run: RunContext,
+     ) -> ChatModelOutput:
+         litellm_input = self._transform_input(input)
+         response = await acompletion(**litellm_input.model_dump())
+         response_message = response.get("choices", [{}])[0].get("message", {})
+         response_content = response_message.get("content", "")
+         tool_calls = response_message.tool_calls
+
+         if tool_calls:
+             litellm_input.messages.append({"role": Role.ASSISTANT, "content": response_content})
+             for tool_call in tool_calls:
+                 function_name = tool_call.function.name
+                 function_to_call: Tool = next(filter(lambda t: t.name == function_name, input.tools))
+
+                 function_args = json.loads(tool_call.function.arguments)
+                 function_response = function_to_call.run(input=function_args)
+                 litellm_input.messages.append({"role": Role.TOOL, "content": function_response})
+
+             response = await acompletion(**litellm_input.model_dump())
+
+         response_output = self._transform_output(response)
+         logger.trace(f"Inference response output:\n{response_output}")
+         return response_output
+
+     async def _create_stream(self, input: ChatModelInput, _: RunContext) -> AsyncGenerator[ChatModelOutput]:
+         litellm_input = self._transform_input(input)
+         parameters = litellm_input.model_dump()
+         parameters["stream"] = True
+         response = await acompletion(**parameters)
+
+         # TODO: handle tool calling for streaming
+         async for chunk in response:
+             response_output = self._transform_output(chunk)
+             if not response_output:
+                 continue
+             yield response_output
+
+     async def _create_structure(self, input: ChatModelStructureInput, run: RunContext) -> ChatModelStructureOutput:
+         if "response_format" not in self.supported_params:
+             logger.warning(f"{self.provider_id} model {self.model_id} does not support structured data.")
+             return await super()._create_structure(input, run)
+         else:
+             response = await self._create(
+                 ChatModelInput(messages=input.messages, response_format=input.schema, abort_signal=input.abort_signal),
+                 run,
+             )
+
+             logger.trace(f"Structured response received:\n{response}")
+
+             text_response = response.get_text_content()
+             result = parse_broken_json(text_response)
+             # TODO: validate result matches expected schema
+             return ChatModelStructureOutput(object=result)
+
+     def _get_model_name(self) -> str:
+         return f"{'ollama_chat' if self.provider_id == 'ollama' else self.provider_id}/{self.model_id}"
+
+     def _transform_input(self, input: ChatModelInput) -> LiteLLMParameters:
+         messages_list = [message.to_plain() for message in input.messages]
+
+         if input.tools:
+             prepared_tools_list = [{"type": "function", "function": tool.prompt_data()} for tool in input.tools]
+         else:
+             prepared_tools_list = None
+
+         model = self._get_model_name()
+
+         return LiteLLMParameters(
+             model=model,
+             messages=messages_list,
+             tools=prepared_tools_list,
+             response_format=input.response_format,
+             **self.settings,
+         )
+
+     def _transform_output(self, chunk: ModelResponse | ModelResponseStream) -> ChatModelOutput:
+         choice = chunk.get("choices", [{}])[0]
+         finish_reason = choice.get("finish_reason")
+         message: Message | None = None
+         usage = choice.get("usage")
+
+         if isinstance(chunk, ModelResponseStream):
+             if finish_reason:
+                 return None
+             content = choice.get("delta", {}).get("content")
+             if choice.get("tool_calls"):
+                 message = ToolMessage(content)
+             elif choice.get("delta"):
+                 message = AssistantMessage(content)
+             else:
+                 # TODO: handle other possible types
+                 raise ChatModelError(f"Unhandled event: {choice}")
+         else:
+             response_message = choice.get("message")
+             content = response_message.get("content")
+             message = AssistantMessage(content)
+
+         return ChatModelOutput(messages=[message], finish_reason=finish_reason, usage=usage)
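`_transform_input` above relies on `LiteLLMParameters` accepting arbitrary extra provider settings (`extra="allow"`), so that `**self.settings` can be forwarded straight to `litellm.acompletion`. A small sketch of that behaviour; the concrete values (model tag, base_url) are illustrative, not taken from the diff:

# Hedged sketch: extra keyword arguments survive model_dump() because
# LiteLLMParameters sets ConfigDict(extra="allow") above.
from beeai_framework.adapters.litellm.chat import LiteLLMParameters

params = LiteLLMParameters(
    model="ollama_chat/llama3.1:8b",               # illustrative model tag
    messages=[{"role": "user", "content": "Hi"}],
    base_url="http://localhost:11434",             # extra field, kept by extra="allow"
)
print(params.model_dump())  # includes base_url alongside the declared fields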
@@ -0,0 +1 @@
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1 @@
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,19 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ import os
+ from typing import Any
+
+ from beeai_framework.adapters.litellm.chat import LiteLLMChatModel
+ from beeai_framework.backend.constants import ProviderName
+ from beeai_framework.utils.custom_logger import BeeLogger
+
+ logger = BeeLogger(__name__)
+
+
+ class OllamaChatModel(LiteLLMChatModel):
+     provider_id: ProviderName = "ollama"
+
+     def __init__(self, model_id: str | None = None, **settings: Any) -> None:
+         self._model_id = model_id if model_id else os.getenv("OLLAMA_CHAT_MODEL", "llama3.1:8b")
+         self.settings = {"base_url": "http://localhost:11434"} | settings
+         super().__init__()
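The Ollama adapter only fills in defaults (model from `OLLAMA_CHAT_MODEL` or `llama3.1:8b`, the local Ollama `base_url`) before delegating to `LiteLLMChatModel`. A usage sketch built from those defaults; the overridden values are illustrative:

# Hedged sketch: construction follows the __init__ shown above; nothing else is implied.
from beeai_framework.adapters.ollama.backend.chat import OllamaChatModel

local = OllamaChatModel()  # llama3.1:8b via http://localhost:11434
remote = OllamaChatModel("llama3.1:70b", base_url="http://ollama.internal:11434")  # settings override the default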
@@ -0,0 +1 @@
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1 @@
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,19 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ import os
+ from typing import Any
+
+ from beeai_framework.adapters.litellm.chat import LiteLLMChatModel
+ from beeai_framework.backend.constants import ProviderName
+ from beeai_framework.utils.custom_logger import BeeLogger
+
+ logger = BeeLogger(__name__)
+
+
+ class WatsonxChatModel(LiteLLMChatModel):
+     provider_id: ProviderName = "watsonx"
+
+     def __init__(self, model_id: str | None = None, **settings: Any) -> None:
+         self._model_id = model_id if model_id else os.getenv("WATSONX_CHAT_MODEL", "ibm/granite-3-8b-instruct")
+         self.settings = settings
+         super().__init__()
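The watsonx adapter follows the same pattern but sets no default connection settings; whatever credentials LiteLLM's watsonx provider expects are passed through `**settings`, which this diff does not show. A minimal construction sketch:

# Hedged sketch: falls back to WATSONX_CHAT_MODEL, then "ibm/granite-3-8b-instruct",
# exactly as the __init__ above does; credential settings are omitted here.
from beeai_framework.adapters.watsonx.backend.chat import WatsonxChatModel

model = WatsonxChatModel("ibm/granite-3-8b-instruct")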
@@ -0,0 +1,6 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from beeai_framework.agents.base import BaseAgent
+ from beeai_framework.agents.errors import AgentError
+
+ __all__ = ["AgentError", "BaseAgent"]
@@ -0,0 +1,60 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC, abstractmethod
+
+ from beeai_framework.agents.types import AgentMeta, BeeRunInput, BeeRunOptions, BeeRunOutput
+ from beeai_framework.context import Run, RunContext, RunContextInput, RunInstance
+ from beeai_framework.emitter import Emitter
+ from beeai_framework.memory import BaseMemory
+ from beeai_framework.utils.models import ModelLike, to_model, to_model_optional
+
+
+ class BaseAgent(ABC):
+     is_running: bool = False
+     emitter: Emitter = None
+
+     def run(self, run_input: ModelLike[BeeRunInput], options: ModelLike[BeeRunOptions] | None = None) -> Run:
+         run_input = to_model(BeeRunInput, run_input)
+         options = to_model_optional(BeeRunOptions, options)
+
+         if self.is_running:
+             raise RuntimeError("Agent is already running!")
+
+         try:
+             return RunContext.enter(
+                 RunInstance(emitter=self.emitter),
+                 RunContextInput(signal=options.signal if options else None, params=(run_input, options)),
+                 lambda context: self._run(run_input, options, context),
+             )
+         except Exception as e:
+             if isinstance(e, RuntimeError):
+                 raise e
+             else:
+                 raise RuntimeError("Error has occurred!") from e
+         finally:
+             self.is_running = False
+
+     @abstractmethod
+     async def _run(self, run_input: BeeRunInput, options: BeeRunOptions | None, context: RunContext) -> BeeRunOutput:
+         pass
+
+     def destroy(self) -> None:
+         self.emitter.destroy()
+
+     @property
+     @abstractmethod
+     def memory(self) -> BaseMemory:
+         pass
+
+     @memory.setter
+     @abstractmethod
+     def memory(self, memory: BaseMemory) -> None:
+         pass
+
+     @property
+     def meta(self) -> AgentMeta:
+         return AgentMeta(
+             name=self.__class__.__name__,
+             description="",
+             tools=[],
+         )
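`BaseAgent` owns the run lifecycle; concrete agents supply `_run` and the `memory` property pair (and, as `BeeAgent` later in this diff does, assign `self.emitter`). A skeleton of that contract; `EchoAgent` and its bodies are invented purely for illustration:

# Hedged sketch of the subclass contract implied by the ABC above.
from beeai_framework.agents.base import BaseAgent
from beeai_framework.agents.types import BeeRunInput, BeeRunOptions, BeeRunOutput
from beeai_framework.context import RunContext
from beeai_framework.memory import BaseMemory, UnconstrainedMemory


class EchoAgent(BaseAgent):
    def __init__(self) -> None:
        self._memory: BaseMemory = UnconstrainedMemory()
        # a real subclass also assigns self.emitter, as BeeAgent does below

    @property
    def memory(self) -> BaseMemory:
        return self._memory

    @memory.setter
    def memory(self, memory: BaseMemory) -> None:
        self._memory = memory

    async def _run(self, run_input: BeeRunInput, options: BeeRunOptions | None, context: RunContext) -> BeeRunOutput:
        # build and return a BeeRunOutput(result=..., iterations=..., memory=...) here
        ...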
@@ -0,0 +1,5 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from beeai_framework.agents.bee.agent import BeeAgent
+
+ __all__ = ["BeeAgent"]
@@ -0,0 +1,165 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from collections.abc import Callable
+ from datetime import UTC, datetime
+
+ from beeai_framework.agents.base import BaseAgent
+ from beeai_framework.agents.runners.base import (
+     BaseRunner,
+     BeeRunnerToolInput,
+     BeeRunnerToolResult,
+     RunnerIteration,
+ )
+ from beeai_framework.agents.runners.default.runner import DefaultRunner
+ from beeai_framework.agents.runners.granite.runner import GraniteRunner
+ from beeai_framework.agents.types import (
+     AgentMeta,
+     BeeAgentExecutionConfig,
+     BeeInput,
+     BeeRunInput,
+     BeeRunOptions,
+     BeeRunOutput,
+ )
+ from beeai_framework.backend import Message
+ from beeai_framework.backend.message import AssistantMessage, MessageMeta, UserMessage
+ from beeai_framework.context import RunContext
+ from beeai_framework.emitter import Emitter, EmitterInput
+ from beeai_framework.memory import BaseMemory
+
+
+ class BeeAgent(BaseAgent):
+     runner: Callable[..., BaseRunner]
+
+     def __init__(self, bee_input: BeeInput) -> None:
+         self.input = bee_input
+         if "granite" in self.input.llm.model_id:
+             self.runner = GraniteRunner
+         else:
+             self.runner = DefaultRunner
+         self.emitter = Emitter.root().child(
+             EmitterInput(
+                 namespace=["agent", "bee"],
+                 creator=self,
+             )
+         )
+
+     @property
+     def memory(self) -> BaseMemory:
+         return self.input.memory
+
+     @memory.setter
+     def memory(self, memory: BaseMemory) -> None:
+         self.input.memory = memory
+
+     @property
+     def meta(self) -> AgentMeta:
+         tools = self.input.tools[:]
+
+         if self.input.meta:
+             return AgentMeta(
+                 name=self.input.meta.name,
+                 description=self.input.meta.description,
+                 extra_description=self.input.meta.extra_description,
+                 tools=tools,
+             )
+
+         extra_description = ["Tools that I can use to accomplish given task."]
+         for tool in tools:
+             extra_description.append(f"Tool ${tool.name}': ${tool.description}.")
+
+         return AgentMeta(
+             name="BeeAI",
+             tools=tools,
+             description="The BeeAI framework demonstrates its ability to auto-correct and adapt in real-time, improving"
+             " the overall reliability and resilience of the system.",
+             extra_description="\n".join(extra_description) if len(tools) > 0 else None,
+         )
+
+     async def _run(self, run_input: BeeRunInput, options: BeeRunOptions | None, context: RunContext) -> BeeRunOutput:
+         runner = self.runner(
+             self.input,
+             (
+                 options
+                 if options
+                 else BeeRunOptions(
+                     execution=self.input.execution
+                     or (options.execution if options is not None else None)
+                     or BeeAgentExecutionConfig(
+                         max_retries_per_step=3,
+                         total_max_retries=20,
+                         max_iterations=10,
+                     ),
+                     signal=None,
+                 )
+             ),
+             context,
+         )
+         await runner.init(run_input)
+
+         final_message: Message | None = None
+         while not final_message:
+             iteration: RunnerIteration = await runner.create_iteration()
+
+             if iteration.state.tool_name and iteration.state.tool_input:
+                 tool_result: BeeRunnerToolResult = await runner.tool(
+                     input=BeeRunnerToolInput(
+                         state=iteration.state,
+                         emitter=iteration.emitter,
+                         meta=iteration.meta,
+                         signal=iteration.signal,
+                     )
+                 )
+                 await runner.memory.add(
+                     AssistantMessage(
+                         content=runner.templates.assistant.render(
+                             {
+                                 "thought": iteration.state.thought,
+                                 "tool_name": iteration.state.tool_name,
+                                 "tool_input": iteration.state.tool_input,
+                                 "tool_output": tool_result.output.to_string(),
+                                 "final_answer": iteration.state.final_answer,
+                             }
+                         ),
+                         meta=MessageMeta({"success": tool_result.success}),
+                     )
+                 )
+                 iteration.state.tool_output = tool_result.output.get_text_content()
+
+                 for key in ["partialUpdate", "update"]:
+                     await iteration.emitter.emit(
+                         key,
+                         {
+                             "data": iteration.state,
+                             "update": {
+                                 "key": "tool_output",
+                                 "value": tool_result.output,
+                                 "parsedValue": tool_result.output.to_string(),
+                             },
+                             "meta": {"success": tool_result.success}, # TODO deleted meta
+                             "memory": runner.memory,
+                         },
+                     )
+
+             if iteration.state.final_answer:
+                 final_message = AssistantMessage(
+                     content=iteration.state.final_answer, meta=MessageMeta({"createdAt": datetime.now(tz=UTC)})
+                 )
+                 await runner.memory.add(final_message)
+                 await iteration.emitter.emit(
+                     "success",
+                     {
+                         "data": final_message,
+                         "iterations": runner.iterations,
+                         "memory": runner.memory,
+                         "meta": iteration.meta,
+                     },
+                 )
+
+         if run_input.prompt is not None:
+             await self.input.memory.add(
+                 UserMessage(content=run_input.prompt, meta=MessageMeta({"createdAt": context.created_at}))
+             )
+
+         await self.input.memory.add(final_message)
+
+         return BeeRunOutput(result=final_message, iterations=runner.iterations, memory=runner.memory)
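Putting the pieces together, a run roughly looks like the sketch below. The `BeeInput`/`BeeRunInput` field names (`llm`, `tools`, `memory`, `prompt`) are inferred from how `BeeAgent._run` reads them above; their definitions live in `beeai_framework/agents/types.py`, which is listed but not shown, and awaiting the returned `Run` directly is likewise an assumption:

# Hedged end-to-end sketch; field names and the awaited Run are assumptions.
import asyncio

from beeai_framework import BeeAgent, UnconstrainedMemory
from beeai_framework.adapters.ollama.backend.chat import OllamaChatModel
from beeai_framework.agents.types import BeeInput, BeeRunInput
from beeai_framework.tools.weather.openmeteo import OpenMeteoTool


async def main() -> None:
    agent = BeeAgent(BeeInput(llm=OllamaChatModel(), tools=[OpenMeteoTool()], memory=UnconstrainedMemory()))
    output = await agent.run(BeeRunInput(prompt="What is the weather like in Prague?"))
    print(output.result)  # the final AssistantMessage


asyncio.run(main())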
@@ -0,0 +1,10 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from beeai_framework.errors import FrameworkError
+
+
+ class AgentError(FrameworkError):
+     """Raised for errors caused by agents."""
+
+     def __init__(self, message: str = "Agent error", *, cause: Exception | None = None) -> None:
+         super().__init__(message, is_fatal=True, is_retryable=False, cause=cause)
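`AgentError` hard-codes `is_fatal=True` and `is_retryable=False`, so agent failures are treated as terminal rather than retried. A tiny sketch of raising and catching it; the call site is illustrative only:

# Hedged sketch; only the constructor defaults above are assumed.
from beeai_framework.agents.errors import AgentError

try:
    raise AgentError("Tool selection failed")
except AgentError as err:
    print(err)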
@@ -0,0 +1,130 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+
+ from beeai_framework.agents.types import (
+     BeeAgentRunIteration,
+     BeeAgentTemplates,
+     BeeInput,
+     BeeIterationResult,
+     BeeMeta,
+     BeeRunInput,
+     BeeRunOptions,
+ )
+ from beeai_framework.cancellation import AbortSignal
+ from beeai_framework.context import RunContext
+ from beeai_framework.emitter.emitter import Emitter
+ from beeai_framework.emitter.types import EmitterInput
+ from beeai_framework.memory.base_memory import BaseMemory
+ from beeai_framework.tools import ToolOutput
+ from beeai_framework.utils.counter import RetryCounter
+
+
+ @dataclass
+ class BeeRunnerLLMInput:
+     meta: BeeMeta
+     signal: AbortSignal
+     emitter: Emitter
+
+
+ @dataclass
+ class RunnerIteration:
+     emitter: Emitter
+     state: BeeIterationResult
+     meta: BeeMeta
+     signal: AbortSignal
+
+
+ @dataclass
+ class BeeRunnerToolResult:
+     output: ToolOutput
+     success: bool
+
+
+ @dataclass
+ class BeeRunnerToolInput:
+     state: BeeIterationResult # TODO BeeIterationToolResult
+     meta: BeeMeta
+     signal: AbortSignal
+     emitter: Emitter
+
+
+ class BaseRunner(ABC):
+     def __init__(self, input: BeeInput, options: BeeRunOptions, run: RunContext) -> None:
+         self._input = input
+         self._options = options
+         self._failed_attempts_counter = RetryCounter(
+             max_retries=(
+                 options.execution.max_iterations if options.execution and options.execution.max_iterations else 0
+             ),
+             error_type=Exception, # TODO Specific error type
+         )
+
+         self._memory: BaseMemory | None = None
+
+         self._iterations: list[BeeAgentRunIteration] = []
+         self._failedAttemptsCounter: RetryCounter = RetryCounter(
+             error_type=Exception, # TODO AgentError
+             max_retries=(
+                 options.execution.total_max_retries if options.execution and options.execution.total_max_retries else 0
+             ),
+         )
+         self._run = run
+
+     @property
+     def iterations(self) -> list[BeeAgentRunIteration]:
+         return self._iterations
+
+     @property
+     def memory(self) -> BaseMemory:
+         if self._memory is not None:
+             return self._memory
+         raise Exception("Memory has not been initialized.")
+
+     async def create_iteration(self) -> RunnerIteration:
+         meta: BeeMeta = BeeMeta(iteration=len(self._iterations) + 1)
+         max_iterations = (
+             self._options.execution.max_iterations
+             if self._options.execution and self._options.execution.max_iterations
+             else 0
+         )
+
+         if meta.iteration > max_iterations:
+             # TODO: Raise Agent Error with metadata
+             # https://github.com/i-am-bee/beeai-framework/blob/aa4d5e6091ed3bab8096492707ceb03d3b03863b/src/agents/bee/runners/base.ts#L70
+             raise Exception(f"Agent was not able to resolve the task in {max_iterations} iterations.")
+
+         emitter = self._run.emitter.child(emitter_input=EmitterInput(group_id=f"`iteration-${meta.iteration}"))
+         iteration: BeeAgentRunIteration = await self.llm(
+             BeeRunnerLLMInput(emitter=emitter, signal=self._run.signal, meta=meta)
+         )
+         self._iterations.append(iteration)
+
+         return RunnerIteration(emitter=emitter, state=iteration.state, meta=meta, signal=self._run.signal)
+
+     async def init(self, input: BeeRunInput) -> None:
+         self._memory = await self.init_memory(input)
+
+     @abstractmethod
+     async def llm(self, input: BeeRunnerLLMInput) -> BeeAgentRunIteration:
+         pass
+
+     @abstractmethod
+     async def tool(self, input: BeeRunnerToolInput) -> BeeRunnerToolResult:
+         pass
+
+     @abstractmethod
+     def default_templates(self) -> BeeAgentTemplates:
+         pass
+
+     @abstractmethod
+     async def init_memory(self, input: BeeRunInput) -> BaseMemory:
+         pass
+
+     @property
+     def templates(self) -> BeeAgentTemplates:
+         # TODO: overrides
+         return self.default_templates()
+
+     # TODO: Serialization
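`BaseRunner` handles the bookkeeping (iteration counting, retry budgets, memory initialization) and leaves four abstract members for model-family-specific runners such as `DefaultRunner` and `GraniteRunner`. A skeleton of that contract; `MyRunner` and its comments are illustrative only:

# Hedged sketch of the concrete-runner contract defined by the ABC above.
from beeai_framework.agents.runners.base import (
    BaseRunner,
    BeeRunnerLLMInput,
    BeeRunnerToolInput,
    BeeRunnerToolResult,
)
from beeai_framework.agents.types import BeeAgentRunIteration, BeeAgentTemplates, BeeRunInput
from beeai_framework.memory.base_memory import BaseMemory


class MyRunner(BaseRunner):
    async def llm(self, input: BeeRunnerLLMInput) -> BeeAgentRunIteration:
        ...  # call the chat model and parse a single reasoning step

    async def tool(self, input: BeeRunnerToolInput) -> BeeRunnerToolResult:
        ...  # execute the tool named in input.state and wrap its output

    def default_templates(self) -> BeeAgentTemplates:
        ...  # prompt templates for this model family

    async def init_memory(self, input: BeeRunInput) -> BaseMemory:
        ...  # seed memory with the system prompt and the user's prompt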