aiagents4pharma 1.19.1__py3-none-any.whl → 1.20.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- aiagents4pharma/talk2biomodels/configs/config.yaml +5 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +129 -73
- aiagents4pharma/talk2scholars/agents/s2_agent.py +2 -1
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +10 -31
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +16 -60
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +9 -8
- aiagents4pharma/talk2scholars/tests/test_integration.py +237 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +180 -0
- aiagents4pharma/talk2scholars/tests/test_s2_agent.py +138 -0
- aiagents4pharma/talk2scholars/tests/{test_langgraph.py → test_s2_tools.py} +79 -151
- aiagents4pharma/talk2scholars/tests/test_state.py +14 -0
- aiagents4pharma/talk2scholars/tools/s2/display_results.py +33 -8
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +10 -23
- aiagents4pharma/talk2scholars/tools/s2/search.py +10 -29
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +4 -29
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/METADATA +19 -3
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/RECORD +20 -15
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,237 @@
|
|
1
|
+
"""
|
2
|
+
Integration tests for talk2scholars system.
|
3
|
+
|
4
|
+
These tests ensure that:
|
5
|
+
1. The main agent and sub-agent work together.
|
6
|
+
2. The agents correctly interact with tools (search, recommendations).
|
7
|
+
3. The full pipeline processes queries and updates state correctly.
|
8
|
+
"""
|
9
|
+
|
10
|
+
# pylint: disable=redefined-outer-name
|
11
|
+
from unittest.mock import patch, Mock
|
12
|
+
import pytest
|
13
|
+
from langchain_core.messages import HumanMessage
|
14
|
+
from ..agents.main_agent import get_app as get_main_app
|
15
|
+
from ..agents.s2_agent import get_app as get_s2_app
|
16
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
17
|
+
|
18
|
+
|
19
|
+
@pytest.fixture(autouse=True)
def mock_hydra():
    """Replace Hydra's ``initialize``/``compose`` for every test in this module.

    ``hydra.compose`` is made to return a stub config carrying a temperature
    and a prompt string for both the main agent and the S2 agent. The compose
    mock is yielded so tests can inspect how it was called.
    """
    with patch("hydra.initialize"):
        with patch("hydra.compose") as compose_stub:
            fake_cfg = Mock()
            t2s_cfg = fake_cfg.agents.talk2scholars
            t2s_cfg.main_agent.temperature = 0
            t2s_cfg.main_agent.main_agent = "Test main agent prompt"
            t2s_cfg.s2_agent.temperature = 0
            t2s_cfg.s2_agent.s2_agent = "Test s2 agent prompt"
            compose_stub.return_value = fake_cfg
            yield compose_stub
|
30
|
+
|
31
|
+
|
32
|
+
@pytest.fixture(autouse=True)
def mock_tools():
    """Stub out the Semantic Scholar tool functions so no real API calls run.

    Yields a dict mapping a short name to each active mock, so individual
    tests can override return values.

    NOTE(review): these patches replace attributes on the tool modules. Any
    module that imported a tool object by name before the patch keeps the
    original object. Confirm the agents actually see these mocks.
    """
    with (
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.search.search_tool"
        ) as search_stub,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.display_results.display_results"
        ) as display_stub,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.single_paper_rec."
            "get_single_paper_recommendations"
        ) as single_rec_stub,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
            "get_multi_paper_recommendations"
        ) as multi_rec_stub,
    ):
        stubs = {
            "search_tool": search_stub,
            "display_results": display_stub,
            "single_paper_rec": single_rec_stub,
            "multi_paper_rec": multi_rec_stub,
        }

        # Canned payloads returned by each stubbed tool.
        stubs["search_tool"].return_value = {"papers": {"id123": "Mock Paper"}}
        stubs["display_results"].return_value = "Displaying Mock Results"
        stubs["single_paper_rec"].return_value = {
            "recommendations": ["Paper A", "Paper B"]
        }
        stubs["multi_paper_rec"].return_value = {
            "multi_recommendations": ["Paper X", "Paper Y"]
        }

        yield stubs
|
65
|
+
|
66
|
+
|
67
|
+
def test_full_workflow():
    """Run a query through the main agent and check that papers come back.

    The S2 search tool is patched to return one canned paper record; the app
    itself is not mocked.
    """
    thread_id = "test_thread"
    main_app = get_main_app(thread_id)

    # Canned search payload, shaped like a real search result.
    paper_id = "530a059cb48477ad1e3d4f8f4b153274c8997332"
    paper_record = {
        "Title": "Explainable Artificial Intelligence",
        "Abstract": None,
        "Citation Count": 5544,
        "Year": "2024",
        "URL": "https://example.com/paper",
    }
    canned_search = {"papers": {paper_id: paper_record}}

    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch(
        "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
        return_value=canned_search,
    ):
        query = HumanMessage(content="Find AI papers")
        result = main_app.invoke(
            Talk2Scholars(messages=[query]),
            config=run_config,
        )

    for key in ("papers", "messages"):
        assert key in result
    assert len(result["papers"]) > 0
|
104
|
+
|
105
|
+
|
106
|
+
def test_s2_agent_execution():
    """Invoke the S2 agent directly and check that it fills ``multi_papers``.

    NOTE(review): whether ``multi_papers`` is populated likely depends on which
    tool the agent's LLM picks for "Get recommendations". Confirm this is
    deterministic in CI.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    request_state = Talk2Scholars(
        messages=[HumanMessage(content="Get recommendations")]
    )
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    result = s2_app.invoke(request_state, config=run_config)

    assert "messages" in result
    assert "multi_papers" in result
    assert result["multi_papers"] is not None
|
127
|
+
|
128
|
+
|
129
|
+
def test_tool_integration(mock_tools):
    """Check that the S2 agent's result carries the papers from the search tool.

    The autouse ``mock_tools`` fixture has already replaced
    ``aiagents4pharma.talk2scholars.tools.s2.search.search_tool``. Configuring
    that mock's return value is therefore enough. The previous version also
    re-patched the same path, which only swapped in a second mock with an
    identical return value.

    NOTE(review): if the S2 agent imported the search tool by name before the
    patch, it will not see this mock. Confirm the assertions below actually
    exercise the canned payload.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    state = Talk2Scholars(
        messages=[HumanMessage(content="Search for AI ethics papers")]
    )

    mock_paper_id = "11159bdb213aaa243916f42f576396d483ba474b"
    mock_response = {
        "papers": {
            mock_paper_id: {
                "Title": "Mock AI Ethics Paper",
                "Abstract": "A study on AI ethics",
                "Citation Count": 100,
                "URL": "https://example.com/mock-paper",
            }
        }
    }

    # The fixture's mock is the object currently installed at the search-tool
    # path, so setting its return value is sufficient.
    mock_tools["search_tool"].return_value = mock_response

    result = s2_app.invoke(
        state,
        config={
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": "test_ns",
                "checkpoint_id": "test_checkpoint",
            }
        },
    )

    assert "papers" in result
    assert len(result["papers"]) > 0  # Verify we have papers
    assert isinstance(result["papers"], dict)  # Verify it's a dictionary
|
171
|
+
|
172
|
+
|
173
|
+
def test_empty_query():
    """Send an empty message to the main agent and expect a prompt for input."""
    thread_id = "test_thread"
    main_app = get_main_app(thread_id)

    empty_state = Talk2Scholars(messages=[HumanMessage(content="")])

    # NOTE: this is a real S2 app, not a stub. The patch below only makes
    # ``s2_agent.get_app`` hand back this pre-built instance.
    prebuilt_s2_app = get_s2_app(thread_id)

    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.get_app",
        return_value=prebuilt_s2_app,
    ):
        result = main_app.invoke(empty_state, config=run_config)

    assert "messages" in result
    reply = result["messages"][-1].content.lower()
    expected_phrases = [
        "no valid input",
        "how can i assist",
        "please provide a query",
    ]
    assert any(phrase in reply for phrase in expected_phrases)
|
204
|
+
|
205
|
+
|
206
|
+
def test_api_failure_handling():
    """Make every ``requests.get`` raise and check the agent reports the failure.

    NOTE(review): the assertions match phrases in the agent's free-text reply,
    so this test may be sensitive to model wording.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    simulated_error = Exception("API Timeout: Connection failed")
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch("requests.get", side_effect=simulated_error):
        nlp_state = Talk2Scholars(
            messages=[HumanMessage(content="Find latest NLP papers")]
        )
        result = s2_app.invoke(nlp_state, config=run_config)

    assert "messages" in result
    reply = result["messages"][-1].content.lower()

    failure_phrases = (
        "unable to retrieve",
        "connection issue",
        "please try again later",
    )
    assert any(phrase in reply for phrase in failure_phrases)
    # The reply should still reference the topic the user asked about.
    assert "nlp papers" in reply
|
@@ -0,0 +1,180 @@
|
|
1
|
+
"""
|
2
|
+
Unit tests for main agent functionality.
|
3
|
+
Tests the supervisor agent's routing logic and state management.
|
4
|
+
"""
|
5
|
+
|
6
|
+
# pylint: disable=redefined-outer-name
|
7
|
+
from unittest.mock import Mock, patch, MagicMock
|
8
|
+
import pytest
|
9
|
+
from langchain_core.messages import HumanMessage, AIMessage
|
10
|
+
from ..agents.main_agent import make_supervisor_node, get_app
|
11
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
12
|
+
|
13
|
+
|
14
|
+
@pytest.fixture(autouse=True)
def mock_hydra():
    """Patch Hydra so the main agent receives a stub config.

    The stub sets ``temperature`` and the ``main_agent`` prompt under
    ``agents.talk2scholars.main_agent``. The ``hydra.compose`` mock is yielded
    so tests can assert on its calls.
    """
    with patch("hydra.initialize"):
        with patch("hydra.compose") as compose_stub:
            stub_cfg = MagicMock()
            main_cfg = stub_cfg.agents.talk2scholars.main_agent
            main_cfg.temperature = 0
            main_cfg.main_agent = "Test prompt"
            compose_stub.return_value = stub_cfg
            yield compose_stub
|
23
|
+
|
24
|
+
|
25
|
+
def test_get_app():
    """Build the main LangGraph app with a mocked LLM and check its nodes."""
    thread_id = "test_thread"
    llm_model = "gpt-4o-mini"
    fake_llm = Mock()

    chat_openai_path = "aiagents4pharma.talk2scholars.agents.main_agent.ChatOpenAI"
    with patch(chat_openai_path, return_value=fake_llm):
        app = get_app(thread_id, llm_model)
        assert app is not None
        # Both the routing node and the S2 sub-agent node must be registered.
        for node_name in ("supervisor", "s2_agent"):
            assert node_name in app.nodes
|
42
|
+
|
43
|
+
|
44
|
+
def test_supervisor_node_execution():
    """Check that the supervisor routes to ``s2_agent`` and forwards the thread id."""
    fake_llm = Mock()
    thread_id = "test_thread"

    # Stand-in for the agent returned by create_react_agent; it answers "s2_agent".
    fake_supervisor = Mock()
    fake_supervisor.invoke.return_value = {
        "messages": [AIMessage(content="s2_agent")]
    }

    with patch(
        "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent",
        return_value=fake_supervisor,
    ):
        supervisor_node = make_supervisor_node(fake_llm, thread_id)
        state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])

        command = supervisor_node(state)

        assert command.goto == "s2_agent"
        # The node must pass the state through with a thread-scoped config.
        expected_config = {"configurable": {"thread_id": thread_id}}
        fake_supervisor.invoke.assert_called_once_with(state, expected_config)
|
70
|
+
|
71
|
+
|
72
|
+
def test_call_s2_agent():
    """Route a request through the main app into a mocked S2 sub-agent."""
    thread_id = "test_thread"
    state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.get_app"
    ) as s2_get_app:
        s2_invoke = s2_get_app.return_value.invoke
        s2_invoke.return_value = {
            "messages": [AIMessage(content="Here are the papers")],
            "papers": {"id123": "Sample Paper"},
        }

        result = get_app(thread_id).invoke(state, config=run_config)

        assert "messages" in result
        assert "papers" in result
        assert result["papers"]["id123"] == "Sample Paper"
        s2_invoke.assert_called_once()
|
100
|
+
|
101
|
+
|
102
|
+
def test_hydra_failure():
    """An exception from ``hydra.initialize`` must surface from ``get_app``."""
    thread_id = "test_thread"
    with patch("hydra.initialize", side_effect=Exception("Hydra error")):
        # ``match`` performs a regex search on str(exception), which is
        # equivalent to a substring check for this plain message.
        with pytest.raises(Exception, match="Hydra error"):
            get_app(thread_id)
|
110
|
+
|
111
|
+
|
112
|
+
class TestMainAgent:
    """Basic tests for the main agent initialization and configuration."""

    def test_supervisor_node_creation(self, mock_hydra):
        """Test that a supervisor node is built and Hydra is consulted once."""
        mock_llm = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ) as mock_create:
            mock_create.return_value = Mock()
            supervisor = make_supervisor_node(mock_llm, thread_id)

            assert supervisor is not None
            assert mock_create.called
            # Building one supervisor node should call hydra.compose once.
            assert mock_hydra.call_count == 1

    def test_supervisor_config_loading(self, mock_hydra):
        """Test that the supervisor requests the main-agent default config."""
        mock_llm = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ):
            make_supervisor_node(mock_llm, thread_id)

            # Verify Hydra was asked for the main-agent override.
            assert mock_hydra.call_count == 1
            assert "agents/talk2scholars/main_agent=default" in str(
                mock_hydra.call_args_list[0]
            )

    def test_react_agent_params(self):
        """Test that the react agent is created with the LLM and a state schema."""
        mock_llm = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ) as mock_create:
            mock_create.return_value = Mock()
            make_supervisor_node(mock_llm, thread_id)

            # Verify create_react_agent was called
            assert mock_create.called

            # Verify the parameters actually passed by make_supervisor_node.
            # (A previous hasattr(mock_create.return_value, "invoke") check was
            # removed: a Mock reports every attribute, so it could never fail.)
            args, kwargs = mock_create.call_args
            assert args[0] == mock_llm  # First argument should be the LLM
            assert "state_schema" in kwargs  # Should have state_schema

    def test_supervisor_custom_config(self, mock_hydra):
        """Test that building a supervisor calls hydra.compose exactly once."""
        mock_llm = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ):
            make_supervisor_node(mock_llm, thread_id)

            # Verify Hydra was called
            mock_hydra.assert_called_once()
|
@@ -0,0 +1,138 @@
|
|
1
|
+
"""
|
2
|
+
Unit tests for the S2 agent (Semantic Scholar sub-agent).
|
3
|
+
"""
|
4
|
+
|
5
|
+
from unittest import mock
|
6
|
+
import pytest
|
7
|
+
from langchain_core.messages import HumanMessage, AIMessage
|
8
|
+
from ..agents.s2_agent import get_app
|
9
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
10
|
+
|
11
|
+
|
12
|
+
@pytest.fixture(autouse=True)
def mock_hydra_fixture():
    """Patch Hydra so the S2 agent receives a stub config.

    The stub sets ``temperature`` and the ``s2_agent`` prompt under
    ``agents.talk2scholars.s2_agent``. The ``hydra.compose`` mock is yielded
    for inspection.
    """
    with mock.patch("hydra.initialize"):
        with mock.patch("hydra.compose") as compose_stub:
            stub_cfg = mock.MagicMock()
            s2_cfg = stub_cfg.agents.talk2scholars.s2_agent
            s2_cfg.temperature = 0
            s2_cfg.s2_agent = "Test prompt"
            compose_stub.return_value = stub_cfg
            yield compose_stub
|
21
|
+
|
22
|
+
|
23
|
+
@pytest.fixture
def mock_tools_fixture():
    """Stub the four S2 tool functions so no real API calls run.

    Yields a list of the mocks in this order: search, display, single-paper
    recommendations, multi-paper recommendations. Each mock returns its own
    fresh ``mock.Mock``.
    """
    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.search.search_tool"
        ) as search_stub,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.display_results.display_results"
        ) as display_stub,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.single_paper_rec."
            "get_single_paper_recommendations"
        ) as single_rec_stub,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
            "get_multi_paper_recommendations"
        ) as multi_rec_stub,
    ):
        stubs = [search_stub, display_stub, single_rec_stub, multi_rec_stub]
        for stub in stubs:
            stub.return_value = mock.Mock()
        yield stubs
|
49
|
+
|
50
|
+
|
51
|
+
@pytest.mark.usefixtures("mock_hydra_fixture")
def test_s2_agent_initialization():
    """Building the S2 app should return an app and call ``create_react_agent``."""
    thread_id = "test_thread"
    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent",
        return_value=mock.Mock(),
    ) as create_stub:
        app = get_app(thread_id)

        assert app is not None
        assert create_stub.called
|
65
|
+
|
66
|
+
|
67
|
+
def test_s2_agent_invocation():
    """Invoke the S2 app with a mocked react agent and check the returned state."""
    thread_id = "test_thread"
    state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])

    fake_agent = mock.Mock()
    fake_agent.invoke.return_value = {
        "messages": [AIMessage(content="Here are some AI papers")],
        "papers": {"id123": "AI Research Paper"},
    }
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent",
        return_value=fake_agent,
    ):
        result = get_app(thread_id).invoke(state, config=run_config)

        assert "messages" in result
        assert "papers" in result
        assert result["papers"]["id123"] == "AI Research Paper"
|
97
|
+
|
98
|
+
|
99
|
+
def test_s2_agent_tools_assignment(request):
    """Ensure that the S2 agent hands exactly four tools to ``ToolNode``.

    The previous check asserted the length of a list the test had assigned to
    a mock attribute itself, so it could never fail. This version inspects the
    argument ``get_app`` actually passed to the patched ``ToolNode``.
    """
    thread_id = "test_thread"

    # Activate the tool-patching fixture by name instead of listing it in the
    # signature, which would redefine the outer fixture name.
    request.getfixturevalue("mock_tools_fixture")

    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
        ) as mock_create,
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.s2_agent.ToolNode"
        ) as mock_toolnode,
    ):
        mock_create.return_value = mock.Mock()

        get_app(thread_id)

        # Ensure the agent wrapped its tools in the (mocked) ToolNode.
        assert mock_toolnode.called
        tool_call = mock_toolnode.call_args
        # Accept the tool list whether it was passed positionally or as tools=.
        tools_arg = tool_call.args[0] if tool_call.args else tool_call.kwargs["tools"]
        assert len(tools_arg) == 4
|
128
|
+
|
129
|
+
|
130
|
+
def test_s2_agent_hydra_failure():
    """An exception from ``hydra.initialize`` must surface from the S2 ``get_app``."""
    thread_id = "test_thread"
    with mock.patch("hydra.initialize", side_effect=Exception("Hydra error")):
        # ``match`` regex-searches str(exception); equivalent to a substring
        # check for this plain message.
        with pytest.raises(Exception, match="Hydra error"):
            get_app(thread_id)
|