aiagents4pharma 1.19.1__py3-none-any.whl → 1.20.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,237 @@
1
+ """
2
+ Integration tests for talk2scholars system.
3
+
4
+ These tests ensure that:
5
+ 1. The main agent and sub-agent work together.
6
+ 2. The agents correctly interact with tools (search, recommendations).
7
+ 3. The full pipeline processes queries and updates state correctly.
8
+ """
9
+
10
+ # pylint: disable=redefined-outer-name
11
+ from unittest.mock import patch, Mock
12
+ import pytest
13
+ from langchain_core.messages import HumanMessage
14
+ from ..agents.main_agent import get_app as get_main_app
15
+ from ..agents.s2_agent import get_app as get_s2_app
16
+ from ..state.state_talk2scholars import Talk2Scholars
17
+
18
+
19
@pytest.fixture(autouse=True)
def mock_hydra():
    """Swap Hydra's entry points for mocks around every test in this module.

    ``hydra.initialize`` is replaced by a default mock, and ``hydra.compose``
    returns a stub config that carries a temperature and a prompt for both
    the main agent and the S2 agent.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    with patch("hydra.initialize"), patch("hydra.compose") as compose_stub:
        config = Mock()
        # Mock hands back the same child object on repeated attribute access,
        # so this alias configures the same nodes as the fully dotted path.
        agents_cfg = config.agents.talk2scholars
        agents_cfg.main_agent.temperature = 0
        agents_cfg.main_agent.main_agent = "Test main agent prompt"
        agents_cfg.s2_agent.temperature = 0
        agents_cfg.s2_agent.s2_agent = "Test s2 agent prompt"
        compose_stub.return_value = config
        yield compose_stub
30
+
31
+
32
@pytest.fixture(autouse=True)
def mock_tools():
    """Patch the four S2 tool attributes with mocks returning canned data.

    The patches stay active for the duration of each test in this module.

    Yields:
        A dict mapping ``search_tool``, ``display_results``,
        ``single_paper_rec`` and ``multi_paper_rec`` to their mocks.
    """
    with (
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
            return_value={"papers": {"id123": "Mock Paper"}},
        ) as search_stub,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.display_results.display_results",
            return_value="Displaying Mock Results",
        ) as display_stub,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.single_paper_rec."
            "get_single_paper_recommendations",
            return_value={"recommendations": ["Paper A", "Paper B"]},
        ) as single_rec_stub,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
            "get_multi_paper_recommendations",
            return_value={"multi_recommendations": ["Paper X", "Paper Y"]},
        ) as multi_rec_stub,
    ):
        yield {
            "search_tool": search_stub,
            "display_results": display_stub,
            "single_paper_rec": single_rec_stub,
            "multi_paper_rec": multi_rec_stub,
        }
65
+
66
+
67
def test_full_workflow():
    """Send one query through the main app and inspect the final state.

    The S2 search tool attribute is patched to return a single canned paper
    while the app is invoked. The resulting state must expose ``papers`` and
    ``messages``, and ``papers`` must not be empty.

    NOTE(review): no LLM class is patched in this test, so it presumably
    reaches whatever model ``get_main_app`` configures -- confirm.
    """
    thread_id = "test_thread"
    main_app = get_main_app(thread_id)

    # Canned search payload: one paper keyed by its paper id.
    expected_paper = {
        "530a059cb48477ad1e3d4f8f4b153274c8997332": {
            "Title": "Explainable Artificial Intelligence",
            "Abstract": None,
            "Citation Count": 5544,
            "Year": "2024",
            "URL": "https://example.com/paper",
        }
    }
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }
    initial_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])

    # Only the tool attribute is patched; the app itself is the real one.
    search_patch = patch(
        "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
        return_value={"papers": expected_paper},
    )
    with search_patch:
        result = main_app.invoke(initial_state, config=run_config)

    for key in ("papers", "messages"):
        assert key in result
    assert len(result["papers"]) > 0
104
+
105
+
106
def test_s2_agent_execution():
    """Invoke the S2 app with a recommendation request and check its state.

    The returned state must contain ``messages`` and a ``multi_papers`` entry
    that is not ``None``.
    """
    thread_id = "test_thread"
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    s2_app = get_s2_app(thread_id)
    request_state = Talk2Scholars(
        messages=[HumanMessage(content="Get recommendations")]
    )

    result = s2_app.invoke(request_state, config=run_config)

    assert "messages" in result
    assert "multi_papers" in result
    assert result["multi_papers"] is not None
127
+
128
+
129
def test_tool_integration(mock_tools):
    """Invoke the S2 app with the search tool stubbed to return one paper.

    The canned payload is installed twice: on the search mock supplied by the
    ``mock_tools`` fixture, and through a fresh patch of the same target that
    is active while the app runs. The resulting state must hold a non-empty
    ``papers`` dict.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    state = Talk2Scholars(
        messages=[HumanMessage(content="Search for AI ethics papers")]
    )

    mock_paper_id = "11159bdb213aaa243916f42f576396d483ba474b"
    canned_papers = {
        mock_paper_id: {
            "Title": "Mock AI Ethics Paper",
            "Abstract": "A study on AI ethics",
            "Citation Count": 100,
            "URL": "https://example.com/mock-paper",
        }
    }
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    # Point the fixture's mock at the canned payload ...
    mock_tools["search_tool"].return_value = {"papers": canned_papers}

    # ... and patch the tool attribute again for the duration of the call.
    with patch(
        "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
        return_value={"papers": canned_papers},
    ):
        result = s2_app.invoke(state, config=run_config)

    assert "papers" in result
    assert len(result["papers"]) > 0  # at least one paper came back
    assert isinstance(result["papers"], dict)  # keyed by paper id
171
+
172
+
173
def test_empty_query():
    """Invoke the main app with an empty human message.

    The last message of the result, lower-cased, must contain one of a small
    set of fallback phrases.
    """
    thread_id = "test_thread"
    main_app = get_main_app(thread_id)

    empty_state = Talk2Scholars(messages=[HumanMessage(content="")])
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    # An S2 app built by get_s2_app (not a Mock) is what the patched
    # factory hands back.
    s2_app = get_s2_app(thread_id)

    with patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.get_app",
        return_value=s2_app,
    ):
        result = main_app.invoke(empty_state, config=run_config)

    assert "messages" in result
    last_message = result["messages"][-1].content.lower()
    accepted_phrases = (
        "no valid input",
        "how can i assist",
        "please provide a query",
    )
    matched = [phrase for phrase in accepted_phrases if phrase in last_message]
    assert matched
204
+
205
+
206
def test_api_failure_handling():
    """Invoke the S2 app while ``requests.get`` raises on every call.

    The last message of the result, lower-cased, must contain one of the
    expected failure phrases and must still mention ``nlp papers``.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    expected_error = "API Timeout: Connection failed"
    failing_get = patch("requests.get", side_effect=Exception(expected_error))
    with failing_get:
        state = Talk2Scholars(messages=[HumanMessage(content="Find latest NLP papers")])
        result = s2_app.invoke(state, config=run_config)

        assert "messages" in result
        last_message = result["messages"][-1].content.lower()

        # Any one of these phrases counts as a graceful failure report.
        failure_phrases = (
            "unable to retrieve",
            "connection issue",
            "please try again later",
        )
        assert any(phrase in last_message for phrase in failure_phrases)
        assert "nlp papers" in last_message  # the reply still names the topic
@@ -0,0 +1,180 @@
1
+ """
2
+ Unit tests for main agent functionality.
3
+ Tests the supervisor agent's routing logic and state management.
4
+ """
5
+
6
+ # pylint: disable=redefined-outer-name
7
+ from unittest.mock import Mock, patch, MagicMock
8
+ import pytest
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
+ from ..agents.main_agent import make_supervisor_node, get_app
11
+ from ..state.state_talk2scholars import Talk2Scholars
12
+
13
+
14
@pytest.fixture(autouse=True)
def mock_hydra():
    """Swap Hydra's entry points for mocks around every test in this module.

    ``hydra.initialize`` becomes a default mock, and ``hydra.compose``
    returns a stub config with the main agent's temperature and prompt set.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    with patch("hydra.initialize"), patch("hydra.compose") as compose_stub:
        config = MagicMock()
        # Repeated attribute access on a mock returns the same child, so the
        # alias below configures the same node as the fully dotted path.
        main_agent_cfg = config.agents.talk2scholars.main_agent
        main_agent_cfg.temperature = 0
        main_agent_cfg.main_agent = "Test prompt"
        compose_stub.return_value = config
        yield compose_stub
23
+
24
+
25
def test_get_app():
    """Build the main LangGraph app with ``ChatOpenAI`` patched out.

    The app must be created and must expose both the ``supervisor`` and the
    ``s2_agent`` nodes.
    """
    thread_id = "test_thread"
    llm_model = "gpt-4o-mini"

    with patch(
        "aiagents4pharma.talk2scholars.agents.main_agent.ChatOpenAI"
    ) as chat_cls:
        # Constructing the patched class yields a plain Mock as the LLM.
        chat_cls.return_value = Mock()

        app = get_app(thread_id, llm_model)

        assert app is not None
        for node_name in ("supervisor", "s2_agent"):
            assert node_name in app.nodes
42
+
43
+
44
def test_supervisor_node_execution():
    """Run the supervisor node once against a stubbed react agent.

    The stub's ``invoke`` answers with an AI message whose content is
    ``s2_agent``. The node result must route to ``s2_agent``, and the stub
    must have been invoked exactly once with the state and the thread config.
    """
    thread_id = "test_thread"
    llm = Mock()

    # Stand-in for the agent that create_react_agent would build.
    react_agent = Mock()
    react_agent.invoke.return_value = {"messages": [AIMessage(content="s2_agent")]}

    state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])

    with patch(
        "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent",
        return_value=react_agent,
    ):
        supervisor_node = make_supervisor_node(llm, thread_id)
        result = supervisor_node(state)

        assert result.goto == "s2_agent"
        expected_config = {"configurable": {"thread_id": thread_id}}
        react_agent.invoke.assert_called_once_with(state, expected_config)
70
+
71
+
72
def test_call_s2_agent():
    """Invoke the main app with the S2 app factory patched.

    The patched factory's app answers ``invoke`` with one AI message and one
    paper. The main app's result must carry that paper, and the stubbed
    ``invoke`` must have been called exactly once.
    """
    thread_id = "test_thread"
    state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch("aiagents4pharma.talk2scholars.agents.s2_agent.get_app") as s2_factory:
        # s2_factory() is the stubbed app; its invoke is what gets recorded.
        stub_invoke = s2_factory.return_value.invoke
        stub_invoke.return_value = {
            "messages": [AIMessage(content="Here are the papers")],
            "papers": {"id123": "Sample Paper"},
        }

        app = get_app(thread_id)
        result = app.invoke(state, config=run_config)

        assert "messages" in result
        assert "papers" in result
        assert result["papers"]["id123"] == "Sample Paper"

        stub_invoke.assert_called_once()
100
+
101
+
102
def test_hydra_failure():
    """An exception raised by ``hydra.initialize`` escapes ``get_app``.

    The exception's text must still contain ``Hydra error``.
    """
    failing_init = patch("hydra.initialize", side_effect=Exception("Hydra error"))
    with failing_init, pytest.raises(Exception) as exc_info:
        get_app("test_thread")

    assert "Hydra error" in str(exc_info.value)
110
+
111
+
112
class TestMainAgent:
    """Checks on how ``make_supervisor_node`` builds the supervisor agent."""

    # Patch target used by every test in this class.
    _REACT_AGENT = (
        "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
    )

    def test_supervisor_node_creation(self, mock_hydra):
        """Building the node calls the agent factory and composes config once."""
        with patch(self._REACT_AGENT) as agent_factory:
            agent_factory.return_value = Mock()
            supervisor = make_supervisor_node(Mock(), "test_thread")

            assert supervisor is not None
            assert agent_factory.called
            # Only the number of hydra.compose calls is checked here, not
            # the arguments they were made with.
            assert mock_hydra.call_count == 1

    def test_supervisor_config_loading(self, mock_hydra):
        """The single ``hydra.compose`` call mentions the main-agent override."""
        with patch(self._REACT_AGENT):
            make_supervisor_node(Mock(), "test_thread")

            assert mock_hydra.call_count == 1
            first_call = str(mock_hydra.call_args_list[0])
            assert "agents/talk2scholars/main_agent=default" in first_call

    def test_react_agent_params(self):
        """The agent factory receives the LLM first and a ``state_schema`` kwarg."""
        llm = Mock()

        with patch(self._REACT_AGENT) as agent_factory:
            agent_factory.return_value = Mock()
            make_supervisor_node(llm, "test_thread")

            assert agent_factory.called

            positional, keyword = agent_factory.call_args
            assert positional[0] == llm  # the LLM is the first positional arg
            assert "state_schema" in keyword
            # NOTE(review): a plain Mock creates attributes on access, so this
            # check passes for any attribute name and cannot fail.
            assert hasattr(agent_factory.return_value, "invoke")

    def test_supervisor_custom_config(self, mock_hydra):
        """Building the node triggers exactly one ``hydra.compose`` call."""
        with patch(self._REACT_AGENT):
            make_supervisor_node(Mock(), "test_thread")

            mock_hydra.assert_called_once()
@@ -0,0 +1,138 @@
1
+ """
2
+ Unit tests for the S2 agent (Semantic Scholar sub-agent).
3
+ """
4
+
5
+ from unittest import mock
6
+ import pytest
7
+ from langchain_core.messages import HumanMessage, AIMessage
8
+ from ..agents.s2_agent import get_app
9
+ from ..state.state_talk2scholars import Talk2Scholars
10
+
11
+
12
@pytest.fixture(autouse=True)
def mock_hydra_fixture():
    """Swap Hydra's entry points for mocks around every test in this module.

    ``hydra.initialize`` becomes a default mock, and ``hydra.compose``
    returns a stub config with the S2 agent's temperature and prompt set.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    init_patch = mock.patch("hydra.initialize")
    compose_patch = mock.patch("hydra.compose")
    with init_patch, compose_patch as compose_stub:
        config = mock.MagicMock()
        # Repeated attribute access on a mock returns the same child, so the
        # alias below configures the same node as the fully dotted path.
        s2_cfg = config.agents.talk2scholars.s2_agent
        s2_cfg.temperature = 0
        s2_cfg.s2_agent = "Test prompt"
        compose_stub.return_value = config
        yield compose_stub
21
+
22
+
23
@pytest.fixture
def mock_tools_fixture():
    """Patch the four S2 tool attributes, each returning a fresh ``Mock``.

    Yields:
        A list of the four patched mocks, in the order search, display,
        single-paper recommendation, multi-paper recommendation.
    """
    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
            return_value=mock.Mock(),
        ) as search_stub,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.display_results.display_results",
            return_value=mock.Mock(),
        ) as display_stub,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.single_paper_rec."
            "get_single_paper_recommendations",
            return_value=mock.Mock(),
        ) as single_rec_stub,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
            "get_multi_paper_recommendations",
            return_value=mock.Mock(),
        ) as multi_rec_stub,
    ):
        yield [search_stub, display_stub, single_rec_stub, multi_rec_stub]
49
+
50
+
51
@pytest.mark.usefixtures("mock_hydra_fixture")
def test_s2_agent_initialization():
    """Building the S2 app yields an object and calls ``create_react_agent``."""
    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent",
        return_value=mock.Mock(),
    ) as agent_factory:
        app = get_app("test_thread")

        assert app is not None
        assert agent_factory.called
65
+
66
+
67
def test_s2_agent_invocation():
    """Invoke the S2 app with ``create_react_agent`` returning a stub agent.

    The stub answers ``invoke`` with one AI message and one paper. The app's
    result must contain ``messages`` and ``papers``, including that paper.
    """
    thread_id = "test_thread"
    state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    # Stand-in for the agent that create_react_agent would build.
    stub_agent = mock.Mock()
    stub_agent.invoke.return_value = {
        "messages": [AIMessage(content="Here are some AI papers")],
        "papers": {"id123": "AI Research Paper"},
    }

    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent",
        return_value=stub_agent,
    ):
        app = get_app(thread_id)
        result = app.invoke(state, config=run_config)

        assert "messages" in result
        assert "papers" in result
        assert result["papers"]["id123"] == "AI Research Paper"
97
+
98
+
99
def test_s2_agent_tools_assignment(request):
    """Build the S2 app with ``ToolNode`` patched and check it was called.

    ``mock_tools_fixture`` is resolved through ``request`` rather than being
    named as a parameter.
    """
    # Resolve the fixture by name instead of listing it in the signature.
    tool_stubs = request.getfixturevalue("mock_tools_fixture")

    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent",
            return_value=mock.Mock(),
        ),
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.s2_agent.ToolNode"
        ) as toolnode_cls,
    ):
        # The patched ToolNode class returns a stub carrying the fixture's
        # tool mocks.
        toolnode_stub = mock.Mock(tools=tool_stubs)
        toolnode_cls.return_value = toolnode_stub

        get_app("test_thread")

        assert toolnode_cls.called
        # NOTE(review): this reads back the list assigned a few lines above;
        # it does not inspect the arguments ToolNode was called with.
        assert len(toolnode_stub.tools) == 4
128
+
129
+
130
def test_s2_agent_hydra_failure():
    """An exception raised by ``hydra.initialize`` escapes ``get_app``.

    The exception's text must still contain ``Hydra error``.
    """
    failing_init = mock.patch(
        "hydra.initialize", side_effect=Exception("Hydra error")
    )
    with failing_init, pytest.raises(Exception) as exc_info:
        get_app("test_thread")

    assert "Hydra error" in str(exc_info.value)