mojentic 0.6.2__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _examples/async_dispatcher_example.py +241 -0
- _examples/async_llm_example.py +236 -0
- mojentic/agents/async_aggregator_agent.py +162 -0
- mojentic/agents/async_aggregator_agent_spec.py +227 -0
- mojentic/agents/async_llm_agent.py +197 -0
- mojentic/agents/async_llm_agent_spec.py +166 -0
- mojentic/agents/base_async_agent.py +27 -0
- mojentic/async_dispatcher.py +134 -0
- mojentic/async_dispatcher_spec.py +244 -0
- mojentic/llm/gateways/models.py +3 -3
- mojentic/llm/gateways/ollama.py +4 -4
- mojentic/llm/gateways/openai.py +3 -3
- mojentic/llm/gateways/openai_messages_adapter.py +8 -4
- mojentic/llm/llm_broker.py +4 -4
- {mojentic-0.6.2.dist-info → mojentic-0.7.1.dist-info}/METADATA +2 -1
- {mojentic-0.6.2.dist-info → mojentic-0.7.1.dist-info}/RECORD +19 -10
- {mojentic-0.6.2.dist-info → mojentic-0.7.1.dist-info}/WHEEL +0 -0
- {mojentic-0.6.2.dist-info → mojentic-0.7.1.dist-info}/licenses/LICENSE.md +0 -0
- {mojentic-0.6.2.dist-info → mojentic-0.7.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import pytest
|
|
3
|
+
from unittest.mock import AsyncMock, MagicMock
|
|
4
|
+
|
|
5
|
+
from mojentic.agents.async_aggregator_agent import AsyncAggregatorAgent
|
|
6
|
+
from mojentic.event import Event
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestEvent1(Event):
    """Test event that carries a text ``message`` payload."""

    message: str
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestEvent2(Event):
    """Test event that carries a text ``data`` payload."""

    data: str
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TestEvent3(Event):
    """Test event that carries an integer ``value`` payload."""

    value: int
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TestResultEvent(Event):
    """Event emitted by the test aggregator; ``result`` holds the combined text."""

    result: str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TestAsyncAggregator(AsyncAggregatorAgent):
    """AsyncAggregatorAgent subclass used by the tests to combine a TestEvent1 with a TestEvent2."""

    async def process_events(self, events):
        """Combine the first TestEvent1 and the first TestEvent2 found in ``events``.

        Returns a one-element list holding a TestResultEvent whose ``result`` is
        "<message> - <data>", or an empty list when either event type is missing.
        """
        def first_of(kind):
            # First event of the requested type, or None when there is none.
            return next((item for item in events if isinstance(item, kind)), None)

        message_event = first_of(TestEvent1)
        data_event = first_of(TestEvent2)

        # Nothing to combine until both kinds of event are present.
        if not (message_event and data_event):
            return []

        combined = TestResultEvent(
            source=type(self),
            correlation_id=message_event.correlation_id,
            result=f"{message_event.message} - {data_event.data}"
        )
        return [combined]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@pytest.fixture
def async_aggregator():
    """Provide a plain AsyncAggregatorAgent that needs TestEvent1 and TestEvent2."""
    aggregator = AsyncAggregatorAgent(event_types_needed=[TestEvent1, TestEvent2])
    return aggregator
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.fixture
def test_async_aggregator():
    """Provide a TestAsyncAggregator that needs TestEvent1 and TestEvent2."""
    aggregator = TestAsyncAggregator(event_types_needed=[TestEvent1, TestEvent2])
    return aggregator
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pytest.mark.asyncio
async def test_async_aggregator_init():
    """A new aggregator keeps its needed event types and starts with empty results and futures."""
    aggregator = AsyncAggregatorAgent(event_types_needed=[TestEvent1, TestEvent2])

    assert aggregator.event_types_needed == [TestEvent1, TestEvent2]
    assert aggregator.results == {}
    assert aggregator.futures == {}
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@pytest.mark.asyncio
async def test_async_aggregator_capture_results(async_aggregator):
    """A captured event is stored under its correlation id."""
    incoming = TestEvent1(source=str, correlation_id="test-id", message="Hello")

    await async_aggregator._capture_results_if_needed(incoming)

    assert "test-id" in async_aggregator.results
    assert len(async_aggregator.results["test-id"]) == 1
    assert async_aggregator.results["test-id"][0] == incoming
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.mark.asyncio
async def test_async_aggregator_has_all_needed(async_aggregator):
    """_has_all_needed only becomes true once both needed event types were captured."""
    message_event = TestEvent1(source=str, correlation_id="test-id", message="Hello")
    data_event = TestEvent2(source=str, correlation_id="test-id", data="World")

    # Nothing captured yet
    assert not await async_aggregator._has_all_needed(message_event)

    # One of the two needed events captured
    await async_aggregator._capture_results_if_needed(message_event)
    assert not await async_aggregator._has_all_needed(message_event)

    # Both needed events captured
    await async_aggregator._capture_results_if_needed(data_event)
    assert await async_aggregator._has_all_needed(data_event)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@pytest.mark.asyncio
async def test_async_aggregator_get_and_reset_results(async_aggregator):
    """_get_and_reset_results hands back the captured events and clears the stored entry."""
    message_event = TestEvent1(source=str, correlation_id="test-id", message="Hello")
    data_event = TestEvent2(source=str, correlation_id="test-id", data="World")

    # Store both events in capture order
    await async_aggregator._capture_results_if_needed(message_event)
    await async_aggregator._capture_results_if_needed(data_event)

    # Fetch the stored events, which also resets the entry
    collected = await async_aggregator._get_and_reset_results(message_event)

    assert len(collected) == 2
    assert collected[0] == message_event
    assert collected[1] == data_event
    assert async_aggregator.results["test-id"] is None
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@pytest.mark.asyncio
async def test_async_aggregator_wait_for_events(async_aggregator):
    """wait_for_events returns both events once they were captured."""
    message_event = TestEvent1(source=str, correlation_id="test-id", message="Hello")
    data_event = TestEvent2(source=str, correlation_id="test-id", data="World")

    # Schedule the waiter as its own task before any event is captured
    waiter = asyncio.create_task(async_aggregator.wait_for_events("test-id", timeout=1))

    # Feed both needed events to the aggregator
    await async_aggregator._capture_results_if_needed(message_event)
    await async_aggregator._capture_results_if_needed(data_event)

    # Collect what the waiter produced
    collected = await waiter

    assert len(collected) == 2
    assert collected[0] == message_event
    assert collected[1] == data_event
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@pytest.mark.asyncio
async def test_async_aggregator_wait_for_events_timeout(async_aggregator):
    """wait_for_events returns the partial results when the timeout expires."""
    message_event = TestEvent1(source=str, correlation_id="test-id", message="Hello")

    # Only one of the two needed events arrives
    await async_aggregator._capture_results_if_needed(message_event)

    # Short timeout so the wait gives up quickly
    partial = await async_aggregator.wait_for_events("test-id", timeout=0.1)

    # Only the single captured event is expected back
    assert len(partial) == 1
    assert partial[0] == message_event
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@pytest.mark.asyncio
async def test_async_aggregator_receive_event_async(test_async_aggregator):
    """receive_event_async yields a result only after both needed events arrived."""
    message_event = TestEvent1(source=str, correlation_id="test-id", message="Hello")
    data_event = TestEvent2(source=str, correlation_id="test-id", data="World")

    # First event alone must not trigger processing
    after_first = await test_async_aggregator.receive_event_async(message_event)
    assert after_first == []

    # Second event completes the set and triggers processing
    after_second = await test_async_aggregator.receive_event_async(data_event)

    assert len(after_second) == 1
    assert isinstance(after_second[0], TestResultEvent)
    assert after_second[0].result == "Hello - World"
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@pytest.mark.asyncio
async def test_async_aggregator_receive_event_async_wrong_order(test_async_aggregator):
    """The combined result is the same when TestEvent2 arrives before TestEvent1."""
    message_event = TestEvent1(source=str, correlation_id="test-id", message="Hello")
    data_event = TestEvent2(source=str, correlation_id="test-id", data="World")

    # TestEvent2 arrives first and must not trigger processing on its own
    after_first = await test_async_aggregator.receive_event_async(data_event)
    assert after_first == []

    # TestEvent1 completes the set and triggers processing
    after_second = await test_async_aggregator.receive_event_async(message_event)

    assert len(after_second) == 1
    assert isinstance(after_second[0], TestResultEvent)
    assert after_second[0].result == "Hello - World"
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@pytest.mark.asyncio
async def test_async_aggregator_receive_event_async_different_correlation_ids(test_async_aggregator):
    """Events are aggregated separately for each correlation id."""
    hello = TestEvent1(source=str, correlation_id="id1", message="Hello")
    world = TestEvent2(source=str, correlation_id="id1", data="World")

    goodbye = TestEvent1(source=str, correlation_id="id2", message="Goodbye")
    universe = TestEvent2(source=str, correlation_id="id2", data="Universe")

    # Complete the pair for id1
    await test_async_aggregator.receive_event_async(hello)
    outcome_id1 = await test_async_aggregator.receive_event_async(world)

    # Complete the pair for id2
    await test_async_aggregator.receive_event_async(goodbye)
    outcome_id2 = await test_async_aggregator.receive_event_async(universe)

    # id1 combines only its own events
    assert len(outcome_id1) == 1
    assert outcome_id1[0].result == "Hello - World"

    # id2 combines only its own events
    assert len(outcome_id2) == 1
    assert outcome_id2[0].result == "Goodbye - Universe"
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@pytest.mark.asyncio
async def test_async_aggregator_process_events_base_implementation(async_aggregator):
    """The un-overridden process_events returns an empty list."""
    message_event = TestEvent1(source=str, correlation_id="test-id", message="Hello")
    data_event = TestEvent2(source=str, correlation_id="test-id", data="World")

    outcome = await async_aggregator.process_events([message_event, data_event])

    assert outcome == []
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Annotated, List, Optional, Type
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
|
|
6
|
+
from mojentic.agents.base_async_agent import BaseAsyncAgent
|
|
7
|
+
from mojentic.context.shared_working_memory import SharedWorkingMemory
|
|
8
|
+
from mojentic.event import Event
|
|
9
|
+
from mojentic.llm.gateways.models import LLMMessage, MessageRole
|
|
10
|
+
from mojentic.llm.llm_broker import LLMBroker
|
|
11
|
+
from mojentic.llm.tools.llm_tool import LLMTool
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BaseAsyncLLMAgent(BaseAsyncAgent):
    """
    BaseAsyncLLMAgent is an asynchronous version of the BaseLLMAgent.
    It uses an LLM to generate responses asynchronously: the broker's synchronous
    calls are executed in a separate thread via ``asyncio.to_thread``.
    """
    llm: LLMBroker
    behaviour: Annotated[str, "The personality and behavioural traits of the agent."]

    def __init__(self, llm: LLMBroker, behaviour: str = "You are a helpful assistant.",
                 tools: Optional[List[LLMTool]] = None, response_model: Optional[Type[BaseModel]] = None):
        """
        Initialize the BaseAsyncLLMAgent.

        Parameters
        ----------
        llm : LLMBroker
            The LLM broker to use for generating responses
        behaviour : str, optional
            The personality and behavioural traits of the agent
        tools : List[LLMTool], optional
            The tools available to the agent. The list is copied, so later calls to
            ``add_tool`` do not modify the list passed in by the caller.
        response_model : Type[BaseModel], optional
            The model to use for responses
        """
        super().__init__()
        self.llm = llm
        self.behaviour = behaviour
        self.response_model = response_model
        # Copy instead of aliasing: add_tool() appends to self.tools and must not
        # mutate a list the caller still owns.
        self.tools = list(tools) if tools else []

    def _create_initial_messages(self):
        """
        Create the initial messages for the LLM.

        Returns
        -------
        list
            A new list holding one system message whose content is ``self.behaviour``
        """
        return [
            LLMMessage(role=MessageRole.System, content=self.behaviour),
        ]

    def add_tool(self, tool):
        """
        Add a tool to the agent.

        Parameters
        ----------
        tool : LLMTool
            The tool to add
        """
        self.tools.append(tool)

    async def generate_response(self, content):
        """
        Generate a response using the LLM asynchronously.

        When a ``response_model`` is configured the broker's ``generate_object`` is
        used (the tools are not passed on that path); otherwise ``generate`` is
        called with the agent's tools.

        Parameters
        ----------
        content : str
            The content to generate a response for

        Returns
        -------
        str or BaseModel
            The generated response
        """
        # Function-scope import, kept local as in the original code.
        import asyncio

        messages = self._create_initial_messages()
        messages.append(LLMMessage(content=content))

        # The broker methods are synchronous, so run them in a separate thread.
        if self.response_model is not None:
            return await asyncio.to_thread(self.llm.generate_object, messages, object_model=self.response_model)
        return await asyncio.to_thread(self.llm.generate, messages, tools=self.tools)

    async def receive_event_async(self, event: Event) -> List[Event]:
        """
        Receive an event and process it asynchronously.
        This method should be overridden by subclasses.

        Parameters
        ----------
        event : Event
            The event to process

        Returns
        -------
        List[Event]
            The events to be processed next; this base implementation returns an empty list
        """
        return []
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class BaseAsyncLLMAgentWithMemory(BaseAsyncLLMAgent):
    """
    BaseAsyncLLMAgentWithMemory is an asynchronous version of the BaseLLMAgentWithMemory.
    It uses an LLM to generate responses asynchronously and maintains a shared working memory.
    """
    instructions: Annotated[str, "The instructions for the agent to follow when receiving events."]

    def __init__(self, llm: LLMBroker, memory: SharedWorkingMemory, behaviour: str, instructions: str,
                 response_model: Type[BaseModel]):
        """
        Initialize the BaseAsyncLLMAgentWithMemory.

        Parameters
        ----------
        llm : LLMBroker
            The LLM broker to use for generating responses
        memory : SharedWorkingMemory
            The shared working memory to use
        behaviour : str
            The personality and behavioural traits of the agent
        instructions : str
            The instructions for the agent to follow when receiving events
        response_model : Type[BaseModel]
            The model class to use for responses. It is subclassed in
            ``generate_response``, so it must be a class, not an instance.
        """
        super().__init__(llm, behaviour, response_model=response_model)
        self.instructions = instructions
        self.memory = memory

    def _create_initial_messages(self):
        """
        Create the initial messages for the LLM.

        Returns
        -------
        list
            The system message from the parent class, followed by a message with the
            JSON dump of the current working memory and a user message with the instructions
        """
        messages = super()._create_initial_messages()
        messages.extend([
            LLMMessage(content=f"This is what you remember:\n{json.dumps(self.memory.get_working_memory(), indent=2)}"
                               f"\n\nRemember anything new you learn by storing it to your working memory in your response."),
            LLMMessage(role=MessageRole.User, content=self.instructions),
        ])
        return messages

    async def generate_response(self, content):
        """
        Generate a response using the LLM asynchronously.

        The response model is extended with a ``memory`` field so the LLM can report
        what it learned. That field is merged into the shared working memory and then
        stripped from the returned object.

        Parameters
        ----------
        content : str
            The content to generate a response for

        Returns
        -------
        BaseModel
            The generated response, validated against ``self.response_model``
        """
        # Function-scope import, kept local as in the original code.
        import asyncio

        class ResponseWithMemory(self.response_model):
            memory: dict = Field(self.memory.get_working_memory(),
                                 description="Add anything new that you have learned here.")

        messages = self._create_initial_messages()
        messages.append(LLMMessage(content=content))

        # generate_object is synchronous, so run it in a separate thread.
        response = await asyncio.to_thread(
            self.llm.generate_object,
            messages=messages,
            object_model=ResponseWithMemory
        )

        self.memory.merge_to_working_memory(response.memory)

        # Drop the helper field before validating against the caller's model.
        payload = response.model_dump(exclude={"memory"})

        return self.response_model.model_validate(payload)
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import pytest
|
|
3
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
from mojentic.agents.async_llm_agent import BaseAsyncLLMAgent
|
|
8
|
+
from mojentic.event import Event
|
|
9
|
+
from mojentic.llm.llm_broker import LLMBroker
|
|
10
|
+
from mojentic.llm.gateways.models import LLMMessage, MessageRole
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestEvent(Event):
    """Minimal event with a text ``message``, used by the agent tests."""

    message: str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestResponse(BaseModel):
    """Minimal structured response model with a required ``answer`` field."""

    answer: str = Field(..., description="The answer to the question")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.fixture
def mock_llm_broker():
    """Provide an LLMBroker mock with canned text and object responses."""
    broker = MagicMock(spec=LLMBroker)
    broker.generate.return_value = "Test response"
    broker.generate_object.return_value = TestResponse(answer="Test answer")
    return broker
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.fixture
def async_llm_agent(mock_llm_broker):
    """Provide a BaseAsyncLLMAgent wired to the mock broker and the TestResponse model."""
    agent = BaseAsyncLLMAgent(
        llm=mock_llm_broker,
        behaviour="You are a test assistant.",
        response_model=TestResponse
    )
    return agent
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.mark.asyncio
async def test_async_llm_agent_init(mock_llm_broker):
    """The constructor stores its arguments and defaults to an empty tool list."""
    created = BaseAsyncLLMAgent(
        llm=mock_llm_broker,
        behaviour="You are a test assistant.",
        response_model=TestResponse
    )

    assert created.llm == mock_llm_broker
    assert created.behaviour == "You are a test assistant."
    assert created.response_model == TestResponse
    assert created.tools == []
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@pytest.mark.asyncio
async def test_async_llm_agent_create_initial_messages(async_llm_agent):
    """The initial messages consist of a single system message holding the behaviour."""
    initial = async_llm_agent._create_initial_messages()

    assert len(initial) == 1
    assert initial[0].role == MessageRole.System
    assert initial[0].content == "You are a test assistant."
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@pytest.mark.asyncio
async def test_async_llm_agent_add_tool(async_llm_agent):
    """add_tool makes the tool available on the agent."""
    extra_tool = MagicMock()

    async_llm_agent.add_tool(extra_tool)

    assert extra_tool in async_llm_agent.tools
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@pytest.mark.asyncio
async def test_async_llm_agent_generate_response_with_model(async_llm_agent, mock_llm_broker):
    """With a response model, generate_response goes through generate_object and returns its result."""
    response = await async_llm_agent.generate_response("Test question")

    # The structured path must be used exactly once, with the configured model forwarded
    mock_llm_broker.generate_object.assert_called_once()
    assert mock_llm_broker.generate_object.call_args.kwargs.get("object_model") is TestResponse

    # The plain-text path must not be touched when a response model is set
    mock_llm_broker.generate.assert_not_called()

    # The broker's object is returned unchanged
    assert isinstance(response, TestResponse)
    assert response.answer == "Test answer"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@pytest.mark.asyncio
async def test_async_llm_agent_generate_response_without_model(mock_llm_broker):
    """Without a response model, generate_response goes through generate and returns its text."""
    agent = BaseAsyncLLMAgent(
        llm=mock_llm_broker,
        behaviour="You are a test assistant."
    )

    response = await agent.generate_response("Test question")

    # The plain-text path must be used exactly once
    mock_llm_broker.generate.assert_called_once()

    # The structured path must not be touched when no response model is set
    mock_llm_broker.generate_object.assert_not_called()

    # The broker's text is returned unchanged
    assert response == "Test response"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@pytest.mark.asyncio
async def test_async_llm_agent_generate_response_with_tools(mock_llm_broker):
    """The agent's tools are forwarded to the broker's generate call."""
    mock_tool = MagicMock()

    agent = BaseAsyncLLMAgent(
        llm=mock_llm_broker,
        behaviour="You are a test assistant.",
        tools=[mock_tool]
    )

    # Only the call made on the broker matters here, so the return value is not kept
    await agent.generate_response("Test question")

    # generate must be called once, receiving the tools as a keyword argument
    mock_llm_broker.generate.assert_called_once()
    assert mock_llm_broker.generate.call_args.kwargs.get('tools') == [mock_tool]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@pytest.mark.asyncio
async def test_async_llm_agent_receive_event_async(async_llm_agent):
    """The base receive_event_async produces no follow-up events."""
    incoming = TestEvent(source=str, message="Test message")

    # BaseAsyncLLMAgent does not react to events on its own
    produced = await async_llm_agent.receive_event_async(incoming)

    assert produced == []
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# Create a subclass for testing the receive_event_async method
|
|
137
|
+
class TestAsyncLLMAgent(BaseAsyncLLMAgent):
    """Agent used by the tests that answers each TestEvent with a new TestEvent."""

    async def receive_event_async(self, event):
        """Reply to a TestEvent with the generated answer; return no events for other types."""
        if not isinstance(event, TestEvent):
            return []

        llm_reply = await self.generate_response(event.message)
        reply_event = TestEvent(
            source=type(self),
            correlation_id=event.correlation_id,
            message=f"Response: {llm_reply.answer}"
        )
        return [reply_event]
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@pytest.mark.asyncio
async def test_subclass_async_llm_agent_receive_event_async(mock_llm_broker):
    """A subclass overriding receive_event_async can turn the LLM answer into a new event."""
    responder = TestAsyncLLMAgent(
        llm=mock_llm_broker,
        behaviour="You are a test assistant.",
        response_model=TestResponse
    )
    incoming = TestEvent(source=str, message="Test message")

    produced = await responder.receive_event_async(incoming)

    assert len(produced) == 1
    assert isinstance(produced[0], TestEvent)
    assert produced[0].message == "Response: Test answer"
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from mojentic.event import Event
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseAsyncAgent:
    """
    Base class for all asynchronous agents.

    It defines the asynchronous ``receive_event_async`` entry point used for event processing.
    """

    async def receive_event_async(self, event: Event) -> List[Event]:
        """
        Handle one event and return the events that should be processed next.

        Every async agent is expected to implement this method. An agent does its work
        based on the incoming event and hands back whatever follow-up events are needed,
        which keeps the agent decoupled from the specifics of event routing and processing.

        Events are subclasses of the Event class. This base implementation does no work
        and returns an empty list.

        :param event: The event to process
        :return: A list of events to be processed next
        """
        follow_ups: List[Event] = []
        return follow_ups
|