aiagents4pharma 1.19.1__py3-none-any.whl → 1.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2biomodels/configs/config.yaml +5 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +129 -73
- aiagents4pharma/talk2scholars/agents/s2_agent.py +2 -1
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +10 -31
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +16 -60
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +9 -8
- aiagents4pharma/talk2scholars/tests/test_integration.py +237 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +180 -0
- aiagents4pharma/talk2scholars/tests/test_s2_agent.py +138 -0
- aiagents4pharma/talk2scholars/tests/{test_langgraph.py → test_s2_tools.py} +79 -151
- aiagents4pharma/talk2scholars/tests/test_state.py +14 -0
- aiagents4pharma/talk2scholars/tools/s2/display_results.py +33 -8
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +10 -23
- aiagents4pharma/talk2scholars/tools/s2/search.py +10 -29
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +4 -29
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/METADATA +19 -3
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/RECORD +20 -15
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.19.1.dist-info → aiagents4pharma-1.20.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,237 @@
|
|
1
|
+
"""
|
2
|
+
Integration tests for talk2scholars system.
|
3
|
+
|
4
|
+
These tests ensure that:
|
5
|
+
1. The main agent and sub-agent work together.
|
6
|
+
2. The agents correctly interact with tools (search, recommendations).
|
7
|
+
3. The full pipeline processes queries and updates state correctly.
|
8
|
+
"""
|
9
|
+
|
10
|
+
# pylint: disable=redefined-outer-name
|
11
|
+
from unittest.mock import patch, Mock
|
12
|
+
import pytest
|
13
|
+
from langchain_core.messages import HumanMessage
|
14
|
+
from ..agents.main_agent import get_app as get_main_app
|
15
|
+
from ..agents.s2_agent import get_app as get_s2_app
|
16
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
17
|
+
|
18
|
+
|
19
|
+
@pytest.fixture(autouse=True)
def mock_hydra():
    """Patch Hydra so tests do not depend on external configuration.

    ``hydra.initialize`` is replaced by a no-op mock and ``hydra.compose``
    returns a ``Mock`` config that carries a temperature and a prompt for
    both the main agent and the S2 agent.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    with patch("hydra.initialize"), patch("hydra.compose") as compose_patch:
        fake_cfg = Mock()
        # Child attributes of a Mock are cached, so this alias points at the
        # same object later reached via fake_cfg.agents.talk2scholars.
        scholars_cfg = fake_cfg.agents.talk2scholars

        scholars_cfg.main_agent.temperature = 0
        scholars_cfg.main_agent.main_agent = "Test main agent prompt"
        scholars_cfg.s2_agent.temperature = 0
        scholars_cfg.s2_agent.s2_agent = "Test s2 agent prompt"

        compose_patch.return_value = fake_cfg
        yield compose_patch
|
30
|
+
|
31
|
+
|
32
|
+
@pytest.fixture(autouse=True)
def mock_tools():
    """Patch the four S2 tool entry points with canned return values.

    Yields:
        dict: the patch mocks, keyed by ``"search_tool"``,
        ``"display_results"``, ``"single_paper_rec"`` and
        ``"multi_paper_rec"``.
    """
    # NOTE(review): these patches replace attributes on the tool modules
    # themselves; whether the agents observe them depends on how the agents
    # import the tools -- confirm against the agent modules.
    with (
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
            return_value={"papers": {"id123": "Mock Paper"}},
        ) as search_patch,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.display_results.display_results",
            return_value="Displaying Mock Results",
        ) as display_patch,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.single_paper_rec."
            "get_single_paper_recommendations",
            return_value={"recommendations": ["Paper A", "Paper B"]},
        ) as single_rec_patch,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
            "get_multi_paper_recommendations",
            return_value={"multi_recommendations": ["Paper X", "Paper Y"]},
        ) as multi_rec_patch,
    ):
        patched_tools = {
            "search_tool": search_patch,
            "display_results": display_patch,
            "single_paper_rec": single_rec_patch,
            "multi_paper_rec": multi_rec_patch,
        }
        yield patched_tools
|
65
|
+
|
66
|
+
|
67
|
+
def test_full_workflow():
    """Run a query through the main agent with the search tool patched.

    The resulting state must expose ``messages`` and a non-empty ``papers``
    entry.
    """
    thread_id = "test_thread"
    main_app = get_main_app(thread_id)

    # Canned search payload shaped like a single paper record.
    papers_payload = {
        "530a059cb48477ad1e3d4f8f4b153274c8997332": {
            "Title": "Explainable Artificial Intelligence",
            "Abstract": None,
            "Citation Count": 5544,
            "Year": "2024",
            "URL": "https://example.com/paper",
        }
    }
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    # Patch the search tool rather than the app itself.
    with patch(
        "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
        return_value={"papers": papers_payload},
    ):
        request_state = Talk2Scholars(
            messages=[HumanMessage(content="Find AI papers")]
        )
        result = main_app.invoke(request_state, config=run_config)

        for required_key in ("papers", "messages"):
            assert required_key in result
        assert len(result["papers"]) > 0
|
104
|
+
|
105
|
+
|
106
|
+
def test_s2_agent_execution():
    """Invoke the S2 agent directly and check the state keys it returns."""
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    request_state = Talk2Scholars(
        messages=[HumanMessage(content="Get recommendations")]
    )
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    result = s2_app.invoke(request_state, config=run_config)

    assert "messages" in result
    # The key must exist and hold something other than None.
    assert "multi_papers" in result
    assert result["multi_papers"] is not None
|
127
|
+
|
128
|
+
|
129
|
+
def test_tool_integration(mock_tools):
    """Invoke the S2 agent with a canned search result and check ``papers``.

    Args:
        mock_tools: dict of patched tool mocks supplied by the ``mock_tools``
            fixture; its ``"search_tool"`` entry is reconfigured here.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    request_state = Talk2Scholars(
        messages=[HumanMessage(content="Search for AI ethics papers")]
    )

    papers_payload = {
        "11159bdb213aaa243916f42f576396d483ba474b": {
            "Title": "Mock AI Ethics Paper",
            "Abstract": "A study on AI ethics",
            "Citation Count": 100,
            "URL": "https://example.com/mock-paper",
        }
    }
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    # Configure the fixture's mock and also patch the tool attribute again,
    # both with the same payload.
    mock_tools["search_tool"].return_value = {"papers": papers_payload}

    with patch(
        "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
        return_value={"papers": papers_payload},
    ):
        result = s2_app.invoke(request_state, config=run_config)

        assert "papers" in result
        found_papers = result["papers"]
        assert len(found_papers) > 0  # at least one paper came back
        assert isinstance(found_papers, dict)  # papers are keyed in a dict
|
171
|
+
|
172
|
+
|
173
|
+
def test_empty_query():
    """Send an empty message to the main agent and check the reply wording.

    The last message (lower-cased) must contain at least one of a small set
    of accepted prompts asking the user for input.
    """
    thread_id = "test_thread"
    main_app = get_main_app(thread_id)

    empty_state = Talk2Scholars(messages=[HumanMessage(content="")])

    # Build the S2 app up front and hand it back from the patched factory.
    s2_app_stub = get_s2_app(thread_id)
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.get_app",
        return_value=s2_app_stub,
    ):
        result = main_app.invoke(empty_state, config=run_config)

    assert "messages" in result
    reply_text = result["messages"][-1].content.lower()
    accepted_phrases = [
        "no valid input",
        "how can i assist",
        "please provide a query",
    ]
    assert any(phrase in reply_text for phrase in accepted_phrases)
|
204
|
+
|
205
|
+
|
206
|
+
def test_api_failure_handling():
    """Invoke the S2 agent while ``requests.get`` raises, and check the reply.

    The last message (lower-cased) must contain one of the accepted failure
    phrases and must still mention the original topic.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    failure_text = "API Timeout: Connection failed"
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch("requests.get", side_effect=Exception(failure_text)):
        request_state = Talk2Scholars(
            messages=[HumanMessage(content="Find latest NLP papers")]
        )
        result = s2_app.invoke(request_state, config=run_config)

        assert "messages" in result
        reply_text = result["messages"][-1].content.lower()

        # Accept any of the phrasings used for a failed retrieval.
        failure_phrases = (
            "unable to retrieve",
            "connection issue",
            "please try again later",
        )
        assert any(phrase in reply_text for phrase in failure_phrases)
        assert "nlp papers" in reply_text  # the topic of the request is kept
|
@@ -0,0 +1,180 @@
|
|
1
|
+
"""
|
2
|
+
Unit tests for main agent functionality.
|
3
|
+
Tests the supervisor agent's routing logic and state management.
|
4
|
+
"""
|
5
|
+
|
6
|
+
# pylint: disable=redefined-outer-name
|
7
|
+
from unittest.mock import Mock, patch, MagicMock
|
8
|
+
import pytest
|
9
|
+
from langchain_core.messages import HumanMessage, AIMessage
|
10
|
+
from ..agents.main_agent import make_supervisor_node, get_app
|
11
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
12
|
+
|
13
|
+
|
14
|
+
@pytest.fixture(autouse=True)
def mock_hydra():
    """Patch Hydra for every test in this module.

    ``hydra.initialize`` becomes a no-op mock and ``hydra.compose`` returns a
    ``MagicMock`` config with the main agent's temperature and prompt set.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    with patch("hydra.initialize"), patch("hydra.compose") as compose_patch:
        fake_cfg = MagicMock()
        main_agent_cfg = fake_cfg.agents.talk2scholars.main_agent
        main_agent_cfg.temperature = 0
        main_agent_cfg.main_agent = "Test prompt"

        compose_patch.return_value = fake_cfg
        yield compose_patch
|
23
|
+
|
24
|
+
|
25
|
+
def test_get_app():
    """Build the main-agent app with a mocked LLM and check its nodes."""
    thread_id = "test_thread"
    llm_model = "gpt-4o-mini"

    # ChatOpenAI is replaced so building the app never constructs a real LLM.
    with patch(
        "aiagents4pharma.talk2scholars.agents.main_agent.ChatOpenAI",
        return_value=Mock(),
    ):
        app = get_app(thread_id, llm_model)

        assert app is not None
        # Both graph nodes must be registered on the app.
        for node_name in ("supervisor", "s2_agent"):
            assert node_name in app.nodes
|
42
|
+
|
43
|
+
|
44
|
+
def test_supervisor_node_execution():
    """Run the supervisor node once and check the routing decision.

    The mocked supervisor agent answers ``"s2_agent"``; the node's result is
    expected to route there, and the agent must have been invoked exactly
    once with the state and the thread configuration.
    """
    thread_id = "test_thread"
    llm_stub = Mock()

    # Stand-in for the agent returned by create_react_agent.
    supervisor_stub = Mock()
    supervisor_stub.invoke.return_value = {
        "messages": [AIMessage(content="s2_agent")]
    }

    with patch(
        "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent",
        return_value=supervisor_stub,
    ):
        supervisor_node = make_supervisor_node(llm_stub, thread_id)

        input_state = Talk2Scholars(
            messages=[HumanMessage(content="Find AI papers")]
        )
        outcome = supervisor_node(input_state)

    assert outcome.goto == "s2_agent"
    expected_config = {"configurable": {"thread_id": thread_id}}
    supervisor_stub.invoke.assert_called_once_with(input_state, expected_config)
|
70
|
+
|
71
|
+
|
72
|
+
def test_call_s2_agent():
    """Invoke the main app with the S2 app factory patched.

    The patched S2 app returns a canned reply; the main app's result must
    carry that reply's ``papers`` and the S2 app must be invoked exactly once.
    """
    thread_id = "test_thread"
    input_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch("aiagents4pharma.talk2scholars.agents.s2_agent.get_app") as mock_s2_app:
        s2_invoke = mock_s2_app.return_value.invoke
        s2_invoke.return_value = {
            "messages": [AIMessage(content="Here are the papers")],
            "papers": {"id123": "Sample Paper"},
        }

        app = get_app(thread_id)
        result = app.invoke(input_state, config=run_config)

        assert "messages" in result
        assert "papers" in result
        assert result["papers"]["id123"] == "Sample Paper"

        s2_invoke.assert_called_once()
|
100
|
+
|
101
|
+
|
102
|
+
def test_hydra_failure():
    """``get_app`` must surface the error raised by ``hydra.initialize``."""
    thread_id = "test_thread"
    hydra_error = Exception("Hydra error")

    with patch("hydra.initialize", side_effect=hydra_error):
        # ``match`` searches the exception text, so this checks the message
        # in the same step as the exception type.
        with pytest.raises(Exception, match="Hydra error"):
            get_app(thread_id)
|
110
|
+
|
111
|
+
|
112
|
+
class TestMainAgent:
    """Checks on how ``make_supervisor_node`` builds the supervisor."""

    def test_supervisor_node_creation(self, mock_hydra):
        """A supervisor is returned and the agent factory is invoked."""
        llm_stub = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent",
            return_value=Mock(),
        ) as factory_patch:
            supervisor = make_supervisor_node(llm_stub, thread_id)

        assert supervisor is not None
        assert factory_patch.called
        # Only the number of hydra.compose calls is checked here, not the
        # arguments.
        assert mock_hydra.call_count == 1

    def test_supervisor_config_loading(self, mock_hydra):
        """``hydra.compose`` is called once with the main-agent override."""
        llm_stub = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ):
            make_supervisor_node(llm_stub, thread_id)

        assert mock_hydra.call_count == 1
        first_compose_call = mock_hydra.call_args_list[0]
        assert "agents/talk2scholars/main_agent=default" in str(first_compose_call)

    def test_react_agent_params(self):
        """The agent factory receives the LLM first and a ``state_schema``."""
        llm_stub = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent",
            return_value=Mock(),
        ) as factory_patch:
            make_supervisor_node(llm_stub, thread_id)

        assert factory_patch.called

        factory_call = factory_patch.call_args
        assert factory_call.args[0] == llm_stub  # LLM is the first positional
        assert "state_schema" in factory_call.kwargs
        # NOTE(review): return_value is a plain Mock, which creates attributes
        # on access, so this check cannot fail.
        assert hasattr(factory_patch.return_value, "invoke")

    def test_supervisor_custom_config(self, mock_hydra):
        """Building the supervisor calls ``hydra.compose`` exactly once."""
        llm_stub = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ):
            make_supervisor_node(llm_stub, thread_id)

        mock_hydra.assert_called_once()
|
@@ -0,0 +1,138 @@
|
|
1
|
+
"""
|
2
|
+
Unit tests for the S2 agent (Semantic Scholar sub-agent).
|
3
|
+
"""
|
4
|
+
|
5
|
+
from unittest import mock
|
6
|
+
import pytest
|
7
|
+
from langchain_core.messages import HumanMessage, AIMessage
|
8
|
+
from ..agents.s2_agent import get_app
|
9
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
10
|
+
|
11
|
+
|
12
|
+
@pytest.fixture(autouse=True)
def mock_hydra_fixture():
    """Patch Hydra so tests do not depend on external configuration.

    ``hydra.initialize`` becomes a no-op mock and ``hydra.compose`` returns a
    ``MagicMock`` config with the S2 agent's temperature and prompt set.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    with mock.patch("hydra.initialize"), mock.patch("hydra.compose") as compose_patch:
        fake_cfg = mock.MagicMock()
        s2_agent_cfg = fake_cfg.agents.talk2scholars.s2_agent
        s2_agent_cfg.temperature = 0
        s2_agent_cfg.s2_agent = "Test prompt"

        compose_patch.return_value = fake_cfg
        yield compose_patch
|
21
|
+
|
22
|
+
|
23
|
+
@pytest.fixture
def mock_tools_fixture():
    """Patch the four S2 tool entry points with mocks.

    Each patched tool returns a fresh ``Mock`` when called.

    Yields:
        list: the four patch mocks in the order search, display,
        single-paper recommendations, multi-paper recommendations.
    """
    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
            return_value=mock.Mock(),
        ) as search_patch,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.display_results.display_results",
            return_value=mock.Mock(),
        ) as display_patch,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.single_paper_rec."
            "get_single_paper_recommendations",
            return_value=mock.Mock(),
        ) as single_rec_patch,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
            "get_multi_paper_recommendations",
            return_value=mock.Mock(),
        ) as multi_rec_patch,
    ):
        patched_tools = [
            search_patch,
            display_patch,
            single_rec_patch,
            multi_rec_patch,
        ]
        yield patched_tools
|
49
|
+
|
50
|
+
|
51
|
+
@pytest.mark.usefixtures("mock_hydra_fixture")
def test_s2_agent_initialization():
    """Build the S2 app with the agent factory patched.

    The app must be created and the factory must have been called.
    """
    thread_id = "test_thread"

    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent",
        return_value=mock.Mock(),
    ) as factory_patch:
        app = get_app(thread_id)

        assert app is not None
        assert factory_patch.called
|
65
|
+
|
66
|
+
|
67
|
+
def test_s2_agent_invocation():
    """Invoke the S2 app with a mocked agent and check the returned state.

    The mocked agent answers with one message and one paper; both keys must
    appear in the app's result with the paper value intact.
    """
    thread_id = "test_thread"
    input_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
    ) as factory_patch:
        agent_stub = mock.Mock()
        agent_stub.invoke.return_value = {
            "messages": [AIMessage(content="Here are some AI papers")],
            "papers": {"id123": "AI Research Paper"},
        }
        factory_patch.return_value = agent_stub

        app = get_app(thread_id)
        result = app.invoke(input_state, config=run_config)

        assert "messages" in result
        assert "papers" in result
        assert result["papers"]["id123"] == "AI Research Paper"
|
97
|
+
|
98
|
+
|
99
|
+
def test_s2_agent_tools_assignment(request):
    """Build the S2 app with ``ToolNode`` patched and check it was used.

    Args:
        request: pytest's built-in fixture, used to fetch
            ``mock_tools_fixture`` by name instead of via the signature.
    """
    thread_id = "test_thread"
    patched_tools = request.getfixturevalue("mock_tools_fixture")

    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent",
            return_value=mock.Mock(),
        ),
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.s2_agent.ToolNode"
        ) as toolnode_patch,
    ):
        # The ToolNode stand-in exposes the fixture's mocks as its tools.
        tool_node_stub = mock.Mock(tools=patched_tools)
        toolnode_patch.return_value = tool_node_stub

        get_app(thread_id)

        assert toolnode_patch.called
        # NOTE(review): ``tools`` is assigned from the fixture just above, so
        # this length reflects the fixture, not what get_app passed in.
        assert len(tool_node_stub.tools) == 4
|
128
|
+
|
129
|
+
|
130
|
+
def test_s2_agent_hydra_failure():
    """``get_app`` must surface the error raised by ``hydra.initialize``."""
    thread_id = "test_thread"
    hydra_error = Exception("Hydra error")

    with mock.patch("hydra.initialize", side_effect=hydra_error):
        # ``match`` searches the exception text, so this checks the message
        # in the same step as the exception type.
        with pytest.raises(Exception, match="Hydra error"):
            get_app(thread_id)
|