mojentic 0.8.4__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _examples/async_dispatcher_example.py +12 -4
- _examples/async_llm_example.py +1 -2
- _examples/broker_as_tool.py +42 -17
- _examples/broker_examples.py +5 -7
- _examples/broker_image_examples.py +1 -1
- _examples/characterize_ollama.py +3 -3
- _examples/characterize_openai.py +1 -1
- _examples/chat_session.py +2 -2
- _examples/chat_session_with_tool.py +2 -2
- _examples/coding_file_tool.py +16 -18
- _examples/current_datetime_tool_example.py +2 -2
- _examples/embeddings.py +1 -1
- _examples/ephemeral_task_manager_example.py +15 -11
- _examples/fetch_openai_models.py +10 -3
- _examples/file_deduplication.py +6 -6
- _examples/file_tool.py +5 -5
- _examples/image_analysis.py +2 -3
- _examples/image_broker.py +1 -1
- _examples/image_broker_splat.py +1 -1
- _examples/iterative_solver.py +3 -3
- _examples/model_characterization.py +2 -0
- _examples/openai_gateway_enhanced_demo.py +15 -5
- _examples/raw.py +1 -1
- _examples/react/agents/decisioning_agent.py +173 -15
- _examples/react/agents/summarization_agent.py +89 -0
- _examples/react/agents/thinking_agent.py +84 -14
- _examples/react/agents/tool_call_agent.py +83 -0
- _examples/react/formatters.py +38 -4
- _examples/react/models/base.py +60 -11
- _examples/react/models/events.py +76 -8
- _examples/react.py +71 -21
- _examples/recursive_agent.py +2 -2
- _examples/simple_llm.py +3 -3
- _examples/simple_llm_repl.py +1 -1
- _examples/simple_structured.py +1 -1
- _examples/simple_tool.py +2 -2
- _examples/solver_chat_session.py +5 -11
- _examples/streaming.py +36 -18
- _examples/tell_user_example.py +4 -4
- _examples/tracer_demo.py +18 -20
- _examples/tracer_qt_viewer.py +49 -46
- _examples/working_memory.py +1 -1
- mojentic/__init__.py +3 -3
- mojentic/agents/__init__.py +26 -8
- mojentic/agents/{agent_broker.py → agent_event_adapter.py} +3 -3
- mojentic/agents/async_aggregator_agent_spec.py +32 -33
- mojentic/agents/async_llm_agent.py +9 -5
- mojentic/agents/async_llm_agent_spec.py +21 -22
- mojentic/agents/base_async_agent.py +2 -2
- mojentic/agents/base_llm_agent.py +6 -2
- mojentic/agents/iterative_problem_solver.py +11 -5
- mojentic/agents/simple_recursive_agent.py +11 -10
- mojentic/agents/simple_recursive_agent_spec.py +423 -0
- mojentic/async_dispatcher.py +0 -1
- mojentic/async_dispatcher_spec.py +1 -1
- mojentic/context/__init__.py +0 -2
- mojentic/dispatcher.py +7 -8
- mojentic/llm/__init__.py +5 -5
- mojentic/llm/gateways/__init__.py +19 -18
- mojentic/llm/gateways/anthropic.py +1 -0
- mojentic/llm/gateways/anthropic_messages_adapter.py +0 -1
- mojentic/llm/gateways/llm_gateway.py +1 -1
- mojentic/llm/gateways/ollama.py +23 -18
- mojentic/llm/gateways/openai.py +243 -44
- mojentic/llm/gateways/openai_message_adapter_spec.py +3 -3
- mojentic/llm/gateways/openai_model_registry.py +7 -6
- mojentic/llm/gateways/openai_model_registry_spec.py +1 -2
- mojentic/llm/gateways/openai_temperature_handling_spec.py +2 -2
- mojentic/llm/llm_broker.py +162 -2
- mojentic/llm/llm_broker_spec.py +76 -2
- mojentic/llm/message_composers.py +6 -3
- mojentic/llm/message_composers_spec.py +5 -1
- mojentic/llm/registry/__init__.py +0 -3
- mojentic/llm/registry/populate_registry_from_ollama.py +2 -2
- mojentic/llm/tools/__init__.py +0 -9
- mojentic/llm/tools/ask_user_tool.py +11 -5
- mojentic/llm/tools/current_datetime.py +9 -6
- mojentic/llm/tools/date_resolver.py +10 -4
- mojentic/llm/tools/date_resolver_spec.py +0 -1
- mojentic/llm/tools/ephemeral_task_manager/append_task_tool.py +4 -1
- mojentic/llm/tools/ephemeral_task_manager/ephemeral_task_list.py +1 -1
- mojentic/llm/tools/ephemeral_task_manager/insert_task_after_tool.py +4 -1
- mojentic/llm/tools/ephemeral_task_manager/prepend_task_tool.py +5 -2
- mojentic/llm/tools/file_manager.py +131 -28
- mojentic/llm/tools/file_manager_spec.py +0 -3
- mojentic/llm/tools/llm_tool.py +1 -1
- mojentic/llm/tools/llm_tool_spec.py +0 -2
- mojentic/llm/tools/organic_web_search.py +4 -2
- mojentic/llm/tools/tell_user_tool.py +6 -2
- mojentic/llm/tools/tool_wrapper.py +2 -2
- mojentic/tracer/__init__.py +1 -10
- mojentic/tracer/event_store.py +7 -8
- mojentic/tracer/event_store_spec.py +1 -2
- mojentic/tracer/null_tracer.py +37 -43
- mojentic/tracer/tracer_events.py +8 -2
- mojentic/tracer/tracer_events_spec.py +6 -7
- mojentic/tracer/tracer_system.py +37 -36
- mojentic/tracer/tracer_system_spec.py +21 -6
- mojentic/utils/__init__.py +1 -1
- mojentic/utils/formatting.py +1 -0
- {mojentic-0.8.4.dist-info → mojentic-1.0.0.dist-info}/METADATA +76 -27
- mojentic-1.0.0.dist-info/RECORD +149 -0
- mojentic-0.8.4.dist-info/RECORD +0 -146
- {mojentic-0.8.4.dist-info → mojentic-1.0.0.dist-info}/WHEEL +0 -0
- {mojentic-0.8.4.dist-info → mojentic-1.0.0.dist-info}/licenses/LICENSE.md +0 -0
- {mojentic-0.8.4.dist-info → mojentic-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,12 @@
|
|
|
1
|
-
import asyncio
|
|
2
1
|
import pytest
|
|
3
|
-
from unittest.mock import
|
|
2
|
+
from unittest.mock import MagicMock
|
|
4
3
|
|
|
5
4
|
from pydantic import BaseModel, Field
|
|
6
5
|
|
|
7
6
|
from mojentic.agents.async_llm_agent import BaseAsyncLLMAgent
|
|
8
7
|
from mojentic.event import Event
|
|
9
8
|
from mojentic.llm.llm_broker import LLMBroker
|
|
10
|
-
from mojentic.llm.gateways.models import
|
|
9
|
+
from mojentic.llm.gateways.models import MessageRole
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
class TestEvent(Event):
|
|
@@ -47,7 +46,7 @@ async def test_async_llm_agent_init(mock_llm_broker):
|
|
|
47
46
|
behaviour="You are a test assistant.",
|
|
48
47
|
response_model=TestResponse
|
|
49
48
|
)
|
|
50
|
-
|
|
49
|
+
|
|
51
50
|
assert agent.llm == mock_llm_broker
|
|
52
51
|
assert agent.behaviour == "You are a test assistant."
|
|
53
52
|
assert agent.response_model == TestResponse
|
|
@@ -58,7 +57,7 @@ async def test_async_llm_agent_init(mock_llm_broker):
|
|
|
58
57
|
async def test_async_llm_agent_create_initial_messages(async_llm_agent):
|
|
59
58
|
"""Test that the BaseAsyncLLMAgent creates initial messages correctly."""
|
|
60
59
|
messages = async_llm_agent._create_initial_messages()
|
|
61
|
-
|
|
60
|
+
|
|
62
61
|
assert len(messages) == 1
|
|
63
62
|
assert messages[0].role == MessageRole.System
|
|
64
63
|
assert messages[0].content == "You are a test assistant."
|
|
@@ -69,7 +68,7 @@ async def test_async_llm_agent_add_tool(async_llm_agent):
|
|
|
69
68
|
"""Test that the BaseAsyncLLMAgent can add tools."""
|
|
70
69
|
mock_tool = MagicMock()
|
|
71
70
|
async_llm_agent.add_tool(mock_tool)
|
|
72
|
-
|
|
71
|
+
|
|
73
72
|
assert mock_tool in async_llm_agent.tools
|
|
74
73
|
|
|
75
74
|
|
|
@@ -77,10 +76,10 @@ async def test_async_llm_agent_add_tool(async_llm_agent):
|
|
|
77
76
|
async def test_async_llm_agent_generate_response_with_model(async_llm_agent, mock_llm_broker):
|
|
78
77
|
"""Test that the BaseAsyncLLMAgent generates responses with a model."""
|
|
79
78
|
response = await async_llm_agent.generate_response("Test question")
|
|
80
|
-
|
|
79
|
+
|
|
81
80
|
# Verify that generate_object was called
|
|
82
81
|
mock_llm_broker.generate_object.assert_called_once()
|
|
83
|
-
|
|
82
|
+
|
|
84
83
|
# Verify the response
|
|
85
84
|
assert isinstance(response, TestResponse)
|
|
86
85
|
assert response.answer == "Test answer"
|
|
@@ -93,12 +92,12 @@ async def test_async_llm_agent_generate_response_without_model(mock_llm_broker):
|
|
|
93
92
|
llm=mock_llm_broker,
|
|
94
93
|
behaviour="You are a test assistant."
|
|
95
94
|
)
|
|
96
|
-
|
|
95
|
+
|
|
97
96
|
response = await agent.generate_response("Test question")
|
|
98
|
-
|
|
97
|
+
|
|
99
98
|
# Verify that generate was called
|
|
100
99
|
mock_llm_broker.generate.assert_called_once()
|
|
101
|
-
|
|
100
|
+
|
|
102
101
|
# Verify the response
|
|
103
102
|
assert response == "Test response"
|
|
104
103
|
|
|
@@ -107,15 +106,15 @@ async def test_async_llm_agent_generate_response_without_model(mock_llm_broker):
|
|
|
107
106
|
async def test_async_llm_agent_generate_response_with_tools(mock_llm_broker):
|
|
108
107
|
"""Test that the BaseAsyncLLMAgent generates responses with tools."""
|
|
109
108
|
mock_tool = MagicMock()
|
|
110
|
-
|
|
109
|
+
|
|
111
110
|
agent = BaseAsyncLLMAgent(
|
|
112
111
|
llm=mock_llm_broker,
|
|
113
112
|
behaviour="You are a test assistant.",
|
|
114
113
|
tools=[mock_tool]
|
|
115
114
|
)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
115
|
+
|
|
116
|
+
await agent.generate_response("Test question")
|
|
117
|
+
|
|
119
118
|
# Verify that generate was called with tools
|
|
120
119
|
mock_llm_broker.generate.assert_called_once()
|
|
121
120
|
args, kwargs = mock_llm_broker.generate.call_args
|
|
@@ -126,17 +125,17 @@ async def test_async_llm_agent_generate_response_with_tools(mock_llm_broker):
|
|
|
126
125
|
async def test_async_llm_agent_receive_event_async(async_llm_agent):
|
|
127
126
|
"""Test that the BaseAsyncLLMAgent's receive_event_async method works."""
|
|
128
127
|
event = TestEvent(source=str, message="Test message")
|
|
129
|
-
|
|
128
|
+
|
|
130
129
|
# The base implementation should return an empty list
|
|
131
130
|
result = await async_llm_agent.receive_event_async(event)
|
|
132
|
-
|
|
131
|
+
|
|
133
132
|
assert result == []
|
|
134
133
|
|
|
135
134
|
|
|
136
135
|
# Create a subclass for testing the receive_event_async method
|
|
137
136
|
class TestAsyncLLMAgent(BaseAsyncLLMAgent):
|
|
138
137
|
"""A test async LLM agent that implements receive_event_async."""
|
|
139
|
-
|
|
138
|
+
|
|
140
139
|
async def receive_event_async(self, event):
|
|
141
140
|
if isinstance(event, TestEvent):
|
|
142
141
|
response = await self.generate_response(event.message)
|
|
@@ -156,11 +155,11 @@ async def test_subclass_async_llm_agent_receive_event_async(mock_llm_broker):
|
|
|
156
155
|
behaviour="You are a test assistant.",
|
|
157
156
|
response_model=TestResponse
|
|
158
157
|
)
|
|
159
|
-
|
|
158
|
+
|
|
160
159
|
event = TestEvent(source=str, message="Test message")
|
|
161
|
-
|
|
160
|
+
|
|
162
161
|
result = await agent.receive_event_async(event)
|
|
163
|
-
|
|
162
|
+
|
|
164
163
|
assert len(result) == 1
|
|
165
164
|
assert isinstance(result[0], TestEvent)
|
|
166
|
-
assert result[0].message == "Response: Test answer"
|
|
165
|
+
assert result[0].message == "Response: Test answer"
|
|
@@ -11,8 +11,8 @@ class BaseAsyncAgent:
|
|
|
11
11
|
|
|
12
12
|
async def receive_event_async(self, event: Event) -> List[Event]:
|
|
13
13
|
"""
|
|
14
|
-
receive_event_async is the method that all async agents must implement.
|
|
15
|
-
events as output.
|
|
14
|
+
receive_event_async is the method that all async agents must implement.
|
|
15
|
+
It takes an event as input and returns a list of events as output.
|
|
16
16
|
|
|
17
17
|
In this way, you can perform work based on the event, and generate whatever subsequent events may need to be
|
|
18
18
|
processed next.
|
|
@@ -55,8 +55,12 @@ class BaseLLMAgentWithMemory(BaseLLMAgent):
|
|
|
55
55
|
def _create_initial_messages(self):
|
|
56
56
|
messages = super()._create_initial_messages()
|
|
57
57
|
messages.extend([
|
|
58
|
-
LLMMessage(
|
|
59
|
-
|
|
58
|
+
LLMMessage(
|
|
59
|
+
content=(f"This is what you remember:\n"
|
|
60
|
+
f"{json.dumps(self.memory.get_working_memory(), indent=2)}"
|
|
61
|
+
f"\n\nRemember anything new you learn by storing it "
|
|
62
|
+
f"to your working memory in your response.")
|
|
63
|
+
),
|
|
60
64
|
LLMMessage(role=MessageRole.User, content=self.instructions),
|
|
61
65
|
])
|
|
62
66
|
return messages
|
|
@@ -43,9 +43,13 @@ class IterativeProblemSolver:
|
|
|
43
43
|
self.available_tools = available_tools or []
|
|
44
44
|
self.chat = ChatSession(
|
|
45
45
|
llm=llm,
|
|
46
|
-
system_prompt=system_prompt or
|
|
47
|
-
|
|
48
|
-
|
|
46
|
+
system_prompt=system_prompt or (
|
|
47
|
+
"You are a problem-solving assistant that can solve complex problems step by step. "
|
|
48
|
+
"You analyze problems, break them down into smaller parts, "
|
|
49
|
+
"and solve them systematically. "
|
|
50
|
+
"If you cannot solve a problem completely in one step, "
|
|
51
|
+
"you make progress and identify what to do next."
|
|
52
|
+
),
|
|
49
53
|
tools=self.available_tools,
|
|
50
54
|
)
|
|
51
55
|
|
|
@@ -87,7 +91,8 @@ class IterativeProblemSolver:
|
|
|
87
91
|
break
|
|
88
92
|
|
|
89
93
|
result = self.chat.send(
|
|
90
|
-
"Summarize the final result, and only the final result,
|
|
94
|
+
"Summarize the final result, and only the final result, "
|
|
95
|
+
"without commenting on the process by which you achieved it.")
|
|
91
96
|
|
|
92
97
|
return result
|
|
93
98
|
|
|
@@ -111,7 +116,8 @@ class IterativeProblemSolver:
|
|
|
111
116
|
Given the user request:
|
|
112
117
|
{problem}
|
|
113
118
|
|
|
114
|
-
Use the tools at your disposal to act on their request.
|
|
119
|
+
Use the tools at your disposal to act on their request.
|
|
120
|
+
You may wish to create a step-by-step plan for more complicated requests.
|
|
115
121
|
|
|
116
122
|
If you cannot provide an answer, say only "FAIL".
|
|
117
123
|
If you have the answer, say only "DONE".
|
|
@@ -35,7 +35,6 @@ class GoalSubmittedEvent(SolverEvent):
|
|
|
35
35
|
"""
|
|
36
36
|
Event triggered when a problem is submitted for solving.
|
|
37
37
|
"""
|
|
38
|
-
pass
|
|
39
38
|
|
|
40
39
|
|
|
41
40
|
class IterationCompletedEvent(SolverEvent):
|
|
@@ -49,21 +48,18 @@ class GoalAchievedEvent(SolverEvent):
|
|
|
49
48
|
"""
|
|
50
49
|
Event triggered when a problem is solved.
|
|
51
50
|
"""
|
|
52
|
-
pass
|
|
53
51
|
|
|
54
52
|
|
|
55
53
|
class GoalFailedEvent(SolverEvent):
|
|
56
54
|
"""
|
|
57
55
|
Event triggered when a problem cannot be solved.
|
|
58
56
|
"""
|
|
59
|
-
pass
|
|
60
57
|
|
|
61
58
|
|
|
62
59
|
class TimeoutEvent(SolverEvent):
|
|
63
60
|
"""
|
|
64
61
|
Event triggered when the problem-solving process times out.
|
|
65
62
|
"""
|
|
66
|
-
pass
|
|
67
63
|
|
|
68
64
|
|
|
69
65
|
class EventEmitter:
|
|
@@ -139,7 +135,8 @@ class SimpleRecursiveAgent:
|
|
|
139
135
|
emitter: EventEmitter
|
|
140
136
|
chat: ChatSession
|
|
141
137
|
|
|
142
|
-
def __init__(self, llm: LLMBroker, available_tools: Optional[List[LLMTool]] = None,
|
|
138
|
+
def __init__(self, llm: LLMBroker, available_tools: Optional[List[LLMTool]] = None,
|
|
139
|
+
max_iterations: int = 5, system_prompt: Optional[str] = None):
|
|
143
140
|
"""
|
|
144
141
|
Initialize the SimpleRecursiveAgent.
|
|
145
142
|
|
|
@@ -160,9 +157,12 @@ class SimpleRecursiveAgent:
|
|
|
160
157
|
# Initialize the chat session
|
|
161
158
|
self.chat = ChatSession(
|
|
162
159
|
llm=llm,
|
|
163
|
-
system_prompt=
|
|
164
|
-
|
|
165
|
-
|
|
160
|
+
system_prompt=(
|
|
161
|
+
system_prompt or
|
|
162
|
+
"You are a problem-solving assistant that can solve complex problems step by step. "
|
|
163
|
+
"You analyze problems, break them down into smaller parts, and solve them systematically. "
|
|
164
|
+
"If you cannot solve a problem completely in one step, you make progress and identify what to do next."
|
|
165
|
+
),
|
|
166
166
|
tools=self.available_tools
|
|
167
167
|
)
|
|
168
168
|
|
|
@@ -207,7 +207,7 @@ class SimpleRecursiveAgent:
|
|
|
207
207
|
try:
|
|
208
208
|
return await asyncio.wait_for(solution_future, timeout=300) # 5 minutes timeout
|
|
209
209
|
except asyncio.TimeoutError:
|
|
210
|
-
timeout_message =
|
|
210
|
+
timeout_message = "Timeout: Could not solve the problem within 300 seconds."
|
|
211
211
|
if not solution_future.done():
|
|
212
212
|
state.solution = timeout_message
|
|
213
213
|
state.is_complete = True
|
|
@@ -277,7 +277,8 @@ class SimpleRecursiveAgent:
|
|
|
277
277
|
Given the user request:
|
|
278
278
|
{state.goal}
|
|
279
279
|
|
|
280
|
-
Use the tools at your disposal to act on their request.
|
|
280
|
+
Use the tools at your disposal to act on their request.
|
|
281
|
+
You may wish to create a step-by-step plan for more complicated requests.
|
|
281
282
|
|
|
282
283
|
If you cannot provide an answer, say only "FAIL".
|
|
283
284
|
If you have the answer, say only "DONE".
|
|
@@ -0,0 +1,423 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for the SimpleRecursiveAgent class.
|
|
3
|
+
|
|
4
|
+
This module contains comprehensive tests for the SimpleRecursiveAgent,
|
|
5
|
+
including event handling, async operation, iteration logic, and edge cases.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import pytest
|
|
10
|
+
from unittest.mock import MagicMock, AsyncMock
|
|
11
|
+
|
|
12
|
+
from mojentic.agents.simple_recursive_agent import (
|
|
13
|
+
SimpleRecursiveAgent,
|
|
14
|
+
GoalState,
|
|
15
|
+
GoalSubmittedEvent,
|
|
16
|
+
IterationCompletedEvent,
|
|
17
|
+
GoalAchievedEvent,
|
|
18
|
+
GoalFailedEvent,
|
|
19
|
+
TimeoutEvent,
|
|
20
|
+
EventEmitter,
|
|
21
|
+
)
|
|
22
|
+
from mojentic.llm.llm_broker import LLMBroker
|
|
23
|
+
from mojentic.llm.chat_session import ChatSession
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.fixture
|
|
27
|
+
def mock_llm_broker():
|
|
28
|
+
"""Create a mock LLM broker for testing."""
|
|
29
|
+
mock_broker = MagicMock(spec=LLMBroker)
|
|
30
|
+
return mock_broker
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@pytest.fixture
|
|
34
|
+
def mock_chat_session(mocker):
|
|
35
|
+
"""Create a mock ChatSession for testing."""
|
|
36
|
+
mock_session = mocker.Mock(spec=ChatSession)
|
|
37
|
+
mock_session.send.return_value = "DONE - Test solution"
|
|
38
|
+
return mock_session
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DescribeEventEmitter:
|
|
42
|
+
"""Tests for the EventEmitter class."""
|
|
43
|
+
|
|
44
|
+
def should_allow_subscribing_to_events(self):
|
|
45
|
+
"""Test that subscribers can be added to event types."""
|
|
46
|
+
emitter = EventEmitter()
|
|
47
|
+
callback = MagicMock()
|
|
48
|
+
|
|
49
|
+
unsubscribe = emitter.subscribe(GoalSubmittedEvent, callback)
|
|
50
|
+
|
|
51
|
+
assert callable(unsubscribe)
|
|
52
|
+
assert GoalSubmittedEvent in emitter.subscribers
|
|
53
|
+
assert callback in emitter.subscribers[GoalSubmittedEvent]
|
|
54
|
+
|
|
55
|
+
def should_allow_unsubscribing_from_events(self):
|
|
56
|
+
"""Test that subscribers can be removed from event types."""
|
|
57
|
+
emitter = EventEmitter()
|
|
58
|
+
callback = MagicMock()
|
|
59
|
+
|
|
60
|
+
unsubscribe = emitter.subscribe(GoalSubmittedEvent, callback)
|
|
61
|
+
unsubscribe()
|
|
62
|
+
|
|
63
|
+
assert callback not in emitter.subscribers[GoalSubmittedEvent]
|
|
64
|
+
|
|
65
|
+
def should_emit_events_to_subscribers(self):
|
|
66
|
+
"""Test that events are delivered to all subscribers."""
|
|
67
|
+
emitter = EventEmitter()
|
|
68
|
+
callback1 = MagicMock()
|
|
69
|
+
callback2 = MagicMock()
|
|
70
|
+
state = GoalState(goal="test")
|
|
71
|
+
event = GoalSubmittedEvent(state=state)
|
|
72
|
+
|
|
73
|
+
emitter.subscribe(GoalSubmittedEvent, callback1)
|
|
74
|
+
emitter.subscribe(GoalSubmittedEvent, callback2)
|
|
75
|
+
emitter.emit(event)
|
|
76
|
+
|
|
77
|
+
callback1.assert_called_once_with(event)
|
|
78
|
+
callback2.assert_called_once_with(event)
|
|
79
|
+
|
|
80
|
+
@pytest.mark.asyncio
|
|
81
|
+
async def should_handle_async_callbacks(self):
|
|
82
|
+
"""Test that async callbacks are properly handled."""
|
|
83
|
+
emitter = EventEmitter()
|
|
84
|
+
async_callback = AsyncMock()
|
|
85
|
+
state = GoalState(goal="test")
|
|
86
|
+
event = GoalSubmittedEvent(state=state)
|
|
87
|
+
|
|
88
|
+
emitter.subscribe(GoalSubmittedEvent, async_callback)
|
|
89
|
+
emitter.emit(event)
|
|
90
|
+
|
|
91
|
+
# Give the event loop a chance to process the async callback
|
|
92
|
+
await asyncio.sleep(0.01)
|
|
93
|
+
|
|
94
|
+
# The async callback should have been called
|
|
95
|
+
assert async_callback.called
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class DescribeGoalState:
|
|
99
|
+
"""Tests for the GoalState class."""
|
|
100
|
+
|
|
101
|
+
def should_initialize_with_defaults(self):
|
|
102
|
+
"""Test that GoalState initializes with correct defaults."""
|
|
103
|
+
state = GoalState(goal="test goal")
|
|
104
|
+
|
|
105
|
+
assert state.goal == "test goal"
|
|
106
|
+
assert state.iteration == 0
|
|
107
|
+
assert state.max_iterations == 5
|
|
108
|
+
assert state.solution is None
|
|
109
|
+
assert state.is_complete is False
|
|
110
|
+
|
|
111
|
+
def should_allow_custom_max_iterations(self):
|
|
112
|
+
"""Test that max_iterations can be customized."""
|
|
113
|
+
state = GoalState(goal="test", max_iterations=10)
|
|
114
|
+
|
|
115
|
+
assert state.max_iterations == 10
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class DescribeSimpleRecursiveAgent:
|
|
119
|
+
"""Tests for the SimpleRecursiveAgent class."""
|
|
120
|
+
|
|
121
|
+
def should_initialize_with_required_parameters(self, mock_llm_broker):
|
|
122
|
+
"""Test that the agent initializes with required parameters."""
|
|
123
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
124
|
+
|
|
125
|
+
assert agent.llm == mock_llm_broker
|
|
126
|
+
assert agent.max_iterations == 5
|
|
127
|
+
assert agent.available_tools == []
|
|
128
|
+
assert isinstance(agent.emitter, EventEmitter)
|
|
129
|
+
assert isinstance(agent.chat, ChatSession)
|
|
130
|
+
|
|
131
|
+
def should_initialize_with_custom_parameters(self, mock_llm_broker):
|
|
132
|
+
"""Test that the agent accepts custom parameters."""
|
|
133
|
+
from mojentic.llm.tools.llm_tool import LLMTool
|
|
134
|
+
|
|
135
|
+
mock_tool = MagicMock(spec=LLMTool)
|
|
136
|
+
agent = SimpleRecursiveAgent(
|
|
137
|
+
llm=mock_llm_broker,
|
|
138
|
+
max_iterations=10,
|
|
139
|
+
available_tools=[mock_tool],
|
|
140
|
+
system_prompt="Custom prompt"
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
assert agent.max_iterations == 10
|
|
144
|
+
assert len(agent.available_tools) == 1
|
|
145
|
+
assert agent.available_tools[0] == mock_tool
|
|
146
|
+
|
|
147
|
+
def should_have_event_handlers_registered(self, mock_llm_broker):
|
|
148
|
+
"""Test that event handlers are registered during initialization."""
|
|
149
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
150
|
+
|
|
151
|
+
assert GoalSubmittedEvent in agent.emitter.subscribers
|
|
152
|
+
assert IterationCompletedEvent in agent.emitter.subscribers
|
|
153
|
+
|
|
154
|
+
@pytest.mark.asyncio
|
|
155
|
+
async def should_solve_problem_with_immediate_success(
|
|
156
|
+
self, mock_llm_broker, mocker
|
|
157
|
+
):
|
|
158
|
+
"""Test that the agent solves a problem that succeeds immediately."""
|
|
159
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker, max_iterations=3)
|
|
160
|
+
|
|
161
|
+
# Mock the chat session to return DONE immediately
|
|
162
|
+
mocker.patch.object(
|
|
163
|
+
agent.chat, "send", return_value="DONE - Solution found"
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
result = await agent.solve("Test problem")
|
|
167
|
+
|
|
168
|
+
assert "DONE" in result
|
|
169
|
+
assert "Solution found" in result
|
|
170
|
+
|
|
171
|
+
@pytest.mark.asyncio
|
|
172
|
+
async def should_solve_problem_with_multiple_iterations(
|
|
173
|
+
self, mock_llm_broker, mocker
|
|
174
|
+
):
|
|
175
|
+
"""Test that the agent handles multiple iterations before success."""
|
|
176
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker, max_iterations=3)
|
|
177
|
+
|
|
178
|
+
# Mock the chat session to return different responses
|
|
179
|
+
responses = [
|
|
180
|
+
"Working on it...",
|
|
181
|
+
"Still working...",
|
|
182
|
+
"DONE - Final solution",
|
|
183
|
+
]
|
|
184
|
+
mocker.patch.object(
|
|
185
|
+
agent.chat, "send", side_effect=responses
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
result = await agent.solve("Test problem")
|
|
189
|
+
|
|
190
|
+
assert "DONE" in result
|
|
191
|
+
assert "Final solution" in result
|
|
192
|
+
|
|
193
|
+
@pytest.mark.asyncio
|
|
194
|
+
async def should_handle_explicit_failure(self, mock_llm_broker, mocker):
|
|
195
|
+
"""Test that the agent handles explicit FAIL responses."""
|
|
196
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker, max_iterations=3)
|
|
197
|
+
|
|
198
|
+
# Mock the chat session to return FAIL
|
|
199
|
+
mocker.patch.object(
|
|
200
|
+
agent.chat, "send", return_value="FAIL - Cannot solve this problem"
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
result = await agent.solve("Impossible problem")
|
|
204
|
+
|
|
205
|
+
assert "Failed to solve" in result
|
|
206
|
+
assert "Cannot solve this problem" in result
|
|
207
|
+
|
|
208
|
+
@pytest.mark.asyncio
|
|
209
|
+
async def should_handle_max_iterations_reached(
|
|
210
|
+
self, mock_llm_broker, mocker
|
|
211
|
+
):
|
|
212
|
+
"""Test that the agent stops at max_iterations."""
|
|
213
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker, max_iterations=2)
|
|
214
|
+
|
|
215
|
+
# Mock the chat session to never return DONE or FAIL
|
|
216
|
+
mocker.patch.object(
|
|
217
|
+
agent.chat, "send", return_value="Still working on it..."
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
result = await agent.solve("Complex problem")
|
|
221
|
+
|
|
222
|
+
assert "Best solution after 2 iterations" in result
|
|
223
|
+
|
|
224
|
+
@pytest.mark.asyncio
|
|
225
|
+
async def should_emit_goal_submitted_event(self, mock_llm_broker, mocker):
|
|
226
|
+
"""Test that GoalSubmittedEvent is emitted."""
|
|
227
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
228
|
+
event_received = []
|
|
229
|
+
|
|
230
|
+
def capture_event(event):
|
|
231
|
+
event_received.append(event)
|
|
232
|
+
|
|
233
|
+
agent.emitter.subscribe(GoalSubmittedEvent, capture_event)
|
|
234
|
+
mocker.patch.object(
|
|
235
|
+
agent.chat, "send", return_value="DONE - Solution"
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
await agent.solve("Test problem")
|
|
239
|
+
|
|
240
|
+
assert len(event_received) == 1
|
|
241
|
+
assert isinstance(event_received[0], GoalSubmittedEvent)
|
|
242
|
+
assert event_received[0].state.goal == "Test problem"
|
|
243
|
+
|
|
244
|
+
@pytest.mark.asyncio
|
|
245
|
+
async def should_emit_iteration_completed_events(
|
|
246
|
+
self, mock_llm_broker, mocker
|
|
247
|
+
):
|
|
248
|
+
"""Test that IterationCompletedEvent is emitted for each iteration."""
|
|
249
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker, max_iterations=3)
|
|
250
|
+
events_received = []
|
|
251
|
+
|
|
252
|
+
def capture_event(event):
|
|
253
|
+
events_received.append(event)
|
|
254
|
+
|
|
255
|
+
agent.emitter.subscribe(IterationCompletedEvent, capture_event)
|
|
256
|
+
responses = ["Working...", "Still working...", "DONE - Solution"]
|
|
257
|
+
mocker.patch.object(
|
|
258
|
+
agent.chat, "send", side_effect=responses
|
|
259
|
+
)
|
|
260
|
+
|
|
261
|
+
await agent.solve("Test problem")
|
|
262
|
+
|
|
263
|
+
assert len(events_received) == 3
|
|
264
|
+
assert all(isinstance(e, IterationCompletedEvent) for e in events_received)
|
|
265
|
+
|
|
266
|
+
@pytest.mark.asyncio
|
|
267
|
+
async def should_emit_goal_achieved_event(self, mock_llm_broker, mocker):
|
|
268
|
+
"""Test that GoalAchievedEvent is emitted on success."""
|
|
269
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
270
|
+
event_received = []
|
|
271
|
+
|
|
272
|
+
def capture_event(event):
|
|
273
|
+
event_received.append(event)
|
|
274
|
+
|
|
275
|
+
agent.emitter.subscribe(GoalAchievedEvent, capture_event)
|
|
276
|
+
mocker.patch.object(
|
|
277
|
+
agent.chat, "send", return_value="DONE - Solution"
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
await agent.solve("Test problem")
|
|
281
|
+
|
|
282
|
+
assert len(event_received) == 1
|
|
283
|
+
assert isinstance(event_received[0], GoalAchievedEvent)
|
|
284
|
+
|
|
285
|
+
@pytest.mark.asyncio
|
|
286
|
+
async def should_emit_goal_failed_event(self, mock_llm_broker, mocker):
|
|
287
|
+
"""Test that GoalFailedEvent is emitted on failure."""
|
|
288
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
289
|
+
event_received = []
|
|
290
|
+
|
|
291
|
+
def capture_event(event):
|
|
292
|
+
event_received.append(event)
|
|
293
|
+
|
|
294
|
+
agent.emitter.subscribe(GoalFailedEvent, capture_event)
|
|
295
|
+
mocker.patch.object(
|
|
296
|
+
agent.chat, "send", return_value="FAIL - Cannot solve"
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
await agent.solve("Impossible problem")
|
|
300
|
+
|
|
301
|
+
assert len(event_received) == 1
|
|
302
|
+
assert isinstance(event_received[0], GoalFailedEvent)
|
|
303
|
+
|
|
304
|
+
@pytest.mark.asyncio
|
|
305
|
+
async def should_handle_timeout(self, mock_llm_broker, mocker):
|
|
306
|
+
"""Test that the agent handles timeout scenarios."""
|
|
307
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
308
|
+
|
|
309
|
+
# Mock the chat session to delay long enough to trigger timeout
|
|
310
|
+
async def slow_send(*args, **kwargs):
|
|
311
|
+
await asyncio.sleep(10)
|
|
312
|
+
return "Never reached"
|
|
313
|
+
|
|
314
|
+
mocker.patch.object(agent, "_generate", side_effect=slow_send)
|
|
315
|
+
|
|
316
|
+
# Override the timeout to be very short for testing
|
|
317
|
+
async def quick_timeout_solve(problem):
|
|
318
|
+
solution_future = asyncio.Future()
|
|
319
|
+
state = GoalState(goal=problem, max_iterations=agent.max_iterations)
|
|
320
|
+
|
|
321
|
+
async def handle_solution_event(event):
|
|
322
|
+
if not solution_future.done():
|
|
323
|
+
solution_future.set_result(event.state.solution)
|
|
324
|
+
|
|
325
|
+
agent.emitter.subscribe(GoalAchievedEvent, handle_solution_event)
|
|
326
|
+
agent.emitter.subscribe(GoalFailedEvent, handle_solution_event)
|
|
327
|
+
agent.emitter.subscribe(TimeoutEvent, handle_solution_event)
|
|
328
|
+
|
|
329
|
+
agent.emitter.emit(GoalSubmittedEvent(state=state))
|
|
330
|
+
|
|
331
|
+
try:
|
|
332
|
+
return await asyncio.wait_for(solution_future, timeout=0.1)
|
|
333
|
+
except asyncio.TimeoutError:
|
|
334
|
+
timeout_message = "Timeout: Could not solve the problem within 0.1 seconds."
|
|
335
|
+
if not solution_future.done():
|
|
336
|
+
state.solution = timeout_message
|
|
337
|
+
state.is_complete = True
|
|
338
|
+
agent.emitter.emit(TimeoutEvent(state=state))
|
|
339
|
+
return timeout_message
|
|
340
|
+
|
|
341
|
+
result = await quick_timeout_solve("Test problem")
|
|
342
|
+
|
|
343
|
+
assert "Timeout" in result
|
|
344
|
+
|
|
345
|
+
@pytest.mark.asyncio
|
|
346
|
+
async def should_use_asyncio_to_thread_for_chat_send(
|
|
347
|
+
self, mock_llm_broker, mocker
|
|
348
|
+
):
|
|
349
|
+
"""Test that _generate uses asyncio.to_thread for synchronous chat.send."""
|
|
350
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
351
|
+
|
|
352
|
+
# Mock asyncio.to_thread
|
|
353
|
+
mock_to_thread = mocker.patch("asyncio.to_thread")
|
|
354
|
+
mock_to_thread.return_value = "Test response"
|
|
355
|
+
|
|
356
|
+
result = await agent._generate("Test prompt")
|
|
357
|
+
|
|
358
|
+
mock_to_thread.assert_called_once_with(agent.chat.send, "Test prompt")
|
|
359
|
+
assert result == "Test response"
|
|
360
|
+
|
|
361
|
+
@pytest.mark.asyncio
|
|
362
|
+
async def should_handle_case_insensitive_done_keyword(
|
|
363
|
+
self, mock_llm_broker, mocker
|
|
364
|
+
):
|
|
365
|
+
"""Test that DONE keyword is case-insensitive."""
|
|
366
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
367
|
+
|
|
368
|
+
test_cases = ["done - solution", "DoNe - solution", "DONE - solution"]
|
|
369
|
+
|
|
370
|
+
for response_text in test_cases:
|
|
371
|
+
mocker.patch.object(agent.chat, "send", return_value=response_text)
|
|
372
|
+
result = await agent.solve("Test problem")
|
|
373
|
+
assert response_text in result
|
|
374
|
+
|
|
375
|
+
@pytest.mark.asyncio
|
|
376
|
+
async def should_handle_case_insensitive_fail_keyword(
|
|
377
|
+
self, mock_llm_broker, mocker
|
|
378
|
+
):
|
|
379
|
+
"""Test that FAIL keyword is case-insensitive."""
|
|
380
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
381
|
+
|
|
382
|
+
test_cases = ["fail - error", "FaIl - error", "FAIL - error"]
|
|
383
|
+
|
|
384
|
+
for response_text in test_cases:
|
|
385
|
+
mocker.patch.object(agent.chat, "send", return_value=response_text)
|
|
386
|
+
result = await agent.solve("Test problem")
|
|
387
|
+
assert "Failed to solve" in result
|
|
388
|
+
|
|
389
|
+
@pytest.mark.asyncio
|
|
390
|
+
async def should_include_goal_in_prompt(self, mock_llm_broker, mocker):
|
|
391
|
+
"""Test that the user's goal is included in the prompt."""
|
|
392
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker)
|
|
393
|
+
|
|
394
|
+
captured_prompts = []
|
|
395
|
+
|
|
396
|
+
def capture_prompt(prompt):
|
|
397
|
+
captured_prompts.append(prompt)
|
|
398
|
+
return "DONE - Solution"
|
|
399
|
+
|
|
400
|
+
mocker.patch.object(agent.chat, "send", side_effect=capture_prompt)
|
|
401
|
+
|
|
402
|
+
await agent.solve("Find the meaning of life")
|
|
403
|
+
|
|
404
|
+
assert len(captured_prompts) > 0
|
|
405
|
+
assert "Find the meaning of life" in captured_prompts[0]
|
|
406
|
+
|
|
407
|
+
@pytest.mark.asyncio
|
|
408
|
+
async def should_increment_iteration_count(self, mock_llm_broker, mocker):
|
|
409
|
+
"""Test that iteration count is properly incremented."""
|
|
410
|
+
agent = SimpleRecursiveAgent(llm=mock_llm_broker, max_iterations=3)
|
|
411
|
+
iterations_seen = []
|
|
412
|
+
|
|
413
|
+
def track_iteration(event):
|
|
414
|
+
iterations_seen.append(event.state.iteration)
|
|
415
|
+
|
|
416
|
+
agent.emitter.subscribe(IterationCompletedEvent, track_iteration)
|
|
417
|
+
|
|
418
|
+
responses = ["Working...", "Still working...", "DONE - Solution"]
|
|
419
|
+
mocker.patch.object(agent.chat, "send", side_effect=responses)
|
|
420
|
+
|
|
421
|
+
await agent.solve("Test problem")
|
|
422
|
+
|
|
423
|
+
assert iterations_seen == [1, 2, 3]
|