aiagents4pharma 1.19.1__py3-none-any.whl → 1.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,237 @@
1
+ """
2
+ Integration tests for talk2scholars system.
3
+
4
+ These tests ensure that:
5
+ 1. The main agent and sub-agent work together.
6
+ 2. The agents correctly interact with tools (search, recommendations).
7
+ 3. The full pipeline processes queries and updates state correctly.
8
+ """
9
+
10
+ # pylint: disable=redefined-outer-name
11
+ from unittest.mock import patch, Mock
12
+ import pytest
13
+ from langchain_core.messages import HumanMessage
14
+ from ..agents.main_agent import get_app as get_main_app
15
+ from ..agents.s2_agent import get_app as get_s2_app
16
+ from ..state.state_talk2scholars import Talk2Scholars
17
+
18
+
19
@pytest.fixture(autouse=True)
def mock_hydra():
    """Patch Hydra so each test in this module receives a canned config.

    ``hydra.initialize`` is replaced by a mock and ``hydra.compose`` returns
    a ``Mock`` whose main-agent and S2-agent temperature and prompt
    attributes are set to the fixed test values below.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    with patch("hydra.initialize"), patch("hydra.compose") as compose_stub:
        config = Mock()
        # Attribute access on a Mock always returns the same child mock, so
        # configuring through this alias configures ``config`` itself.
        scholars_cfg = config.agents.talk2scholars
        scholars_cfg.main_agent.temperature = 0
        scholars_cfg.main_agent.main_agent = "Test main agent prompt"
        scholars_cfg.s2_agent.temperature = 0
        scholars_cfg.s2_agent.s2_agent = "Test s2 agent prompt"
        compose_stub.return_value = config
        yield compose_stub
30
+
31
+
32
@pytest.fixture(autouse=True)
def mock_tools():
    """Patch the four S2 tool entry points with mocks returning canned data.

    Yields:
        A dict mapping ``"search_tool"``, ``"display_results"``,
        ``"single_paper_rec"`` and ``"multi_paper_rec"`` to the mock that
        replaced the corresponding tool for the duration of the test.
    """
    with (
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.search.search_tool"
        ) as search_stub,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.display_results.display_results"
        ) as display_stub,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.single_paper_rec."
            "get_single_paper_recommendations"
        ) as single_rec_stub,
        patch(
            "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
            "get_multi_paper_recommendations"
        ) as multi_rec_stub,
    ):
        # Canned return values; individual tests may override them through
        # the yielded mapping.
        search_stub.return_value = {"papers": {"id123": "Mock Paper"}}
        display_stub.return_value = "Displaying Mock Results"
        single_rec_stub.return_value = {"recommendations": ["Paper A", "Paper B"]}
        multi_rec_stub.return_value = {
            "multi_recommendations": ["Paper X", "Paper Y"]
        }

        stubs_by_name = {
            "search_tool": search_stub,
            "display_results": display_stub,
            "single_paper_rec": single_rec_stub,
            "multi_paper_rec": multi_rec_stub,
        }
        yield stubs_by_name
65
+
66
+
67
def test_full_workflow():
    """Invoke the main app with a paper query and check the resulting state.

    The search tool is patched to return a single paper record; the test
    then asserts that the result exposes ``"messages"`` and a non-empty
    ``"papers"`` entry.
    """
    thread_id = "test_thread"
    main_app = get_main_app(thread_id)

    # Paper record shaped like the real search output: keyed by paper id.
    paper_record = {
        "Title": "Explainable Artificial Intelligence",
        "Abstract": None,
        "Citation Count": 5544,
        "Year": "2024",
        "URL": "https://example.com/paper",
    }
    expected_paper = {"530a059cb48477ad1e3d4f8f4b153274c8997332": paper_record}

    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    # Patch the search tool rather than the app itself.
    with patch(
        "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
        return_value={"papers": expected_paper},
    ):
        state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
        result = main_app.invoke(state, config=run_config)

    for required_key in ("papers", "messages"):
        assert required_key in result
    assert len(result["papers"]) > 0
104
+
105
+
106
def test_s2_agent_execution():
    """Invoke the S2 app with a recommendation request and check its state.

    Asserts that the result contains ``"messages"`` and a ``"multi_papers"``
    entry that is not ``None``.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    request_message = HumanMessage(content="Get recommendations")
    state = Talk2Scholars(messages=[request_message])

    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }
    result = s2_app.invoke(state, config=run_config)

    assert "messages" in result
    assert "multi_papers" in result
    assert result["multi_papers"] is not None
127
+
128
+
129
def test_tool_integration(mock_tools):
    """Run the S2 app with the search tool returning a known paper.

    The canned paper is installed twice: on the ``mock_tools`` fixture mock
    and through a fresh ``patch`` of the same target. The test then asserts
    that ``result["papers"]`` is a non-empty dict.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    state = Talk2Scholars(
        messages=[HumanMessage(content="Search for AI ethics papers")]
    )

    mock_paper_id = "11159bdb213aaa243916f42f576396d483ba474b"
    mock_response = {
        "papers": {
            mock_paper_id: {
                "Title": "Mock AI Ethics Paper",
                "Abstract": "A study on AI ethics",
                "Citation Count": 100,
                "URL": "https://example.com/mock-paper",
            }
        }
    }
    papers = mock_response["papers"]

    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    # Point the fixture mock at the canned papers, then patch the tool
    # directly as well so both routes yield the same data.
    mock_tools["search_tool"].return_value = {"papers": papers}
    with patch(
        "aiagents4pharma.talk2scholars.tools.s2.search.search_tool",
        return_value={"papers": papers},
    ):
        result = s2_app.invoke(state, config=run_config)

    assert "papers" in result
    assert len(result["papers"]) > 0  # at least one paper came back
    assert isinstance(result["papers"], dict)  # keyed by paper id
171
+
172
+
173
def test_empty_query():
    """Invoke the main app with an empty message and check the reply text.

    The final message, lower-cased, must contain at least one of the
    fallback phrases listed below.
    """
    thread_id = "test_thread"
    main_app = get_main_app(thread_id)

    state = Talk2Scholars(messages=[HumanMessage(content="")])

    # This is the app returned by ``get_s2_app`` (not a Mock); it is
    # installed as the return value of the patched ``s2_agent.get_app``.
    # NOTE(review): the patch is applied after ``main_app`` was built --
    # confirm it still affects the code path exercised by ``invoke``.
    mock_s2_app = get_s2_app(thread_id)

    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.get_app",
        return_value=mock_s2_app,
    ):
        result = main_app.invoke(state, config=run_config)

    assert "messages" in result
    last_message = result["messages"][-1].content.lower()
    fallback_phrases = (
        "no valid input",
        "how can i assist",
        "please provide a query",
    )
    assert any(phrase in last_message for phrase in fallback_phrases)
204
+
205
+
206
def test_api_failure_handling():
    """Check that the S2 app reports a failed HTTP call instead of crashing.

    ``requests.get`` is patched to raise, the S2 app is invoked with a
    paper query, and the final message (lower-cased) must contain one of
    the failure phrases below and still mention ``"nlp papers"``.
    """
    thread_id = "test_thread"
    s2_app = get_s2_app(thread_id)

    expected_error = "API Timeout: Connection failed"
    with patch("requests.get", side_effect=Exception(expected_error)):
        state = Talk2Scholars(messages=[HumanMessage(content="Find latest NLP papers")])

        result = s2_app.invoke(
            state,
            config={
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": "test_ns",
                    "checkpoint_id": "test_checkpoint",
                }
            },
        )

    assert "messages" in result
    last_message = result["messages"][-1].content.lower()

    # NOTE(review): no LLM is patched in this test, so the wording checked
    # here presumably depends on the model's reply -- confirm it is stable.
    failure_phrases = (
        "unable to retrieve",
        "connection issue",
        "please try again later",
    )
    # Generator form avoids building a throwaway list and short-circuits on
    # the first matching phrase.
    assert any(phrase in last_message for phrase in failure_phrases)
    assert "nlp papers" in last_message  # the query context is preserved
@@ -0,0 +1,180 @@
1
+ """
2
+ Unit tests for main agent functionality.
3
+ Tests the supervisor agent's routing logic and state management.
4
+ """
5
+
6
+ # pylint: disable=redefined-outer-name
7
+ from unittest.mock import Mock, patch, MagicMock
8
+ import pytest
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
+ from ..agents.main_agent import make_supervisor_node, get_app
11
+ from ..state.state_talk2scholars import Talk2Scholars
12
+
13
+
14
@pytest.fixture(autouse=True)
def mock_hydra():
    """Patch Hydra so main-agent tests receive a canned configuration.

    ``hydra.initialize`` is replaced by a mock and ``hydra.compose`` returns
    a ``MagicMock`` whose main-agent temperature and prompt are preset.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    with patch("hydra.initialize"), patch("hydra.compose") as compose_stub:
        config = MagicMock()
        # Child mocks are cached per attribute, so this alias configures the
        # same object reachable from ``config``.
        main_agent_cfg = config.agents.talk2scholars.main_agent
        main_agent_cfg.temperature = 0
        main_agent_cfg.main_agent = "Test prompt"
        compose_stub.return_value = config
        yield compose_stub
23
+
24
+
25
def test_get_app():
    """Build the main app with a patched LLM and check its graph nodes.

    ``ChatOpenAI`` in the main-agent module is patched to return a ``Mock``;
    the app must then exist and expose ``"supervisor"`` and ``"s2_agent"``
    in ``app.nodes``.
    """
    thread_id = "test_thread"
    llm_model = "gpt-4o-mini"

    # Stand-in returned whenever the main agent constructs its LLM.
    llm_stub = Mock()
    chat_openai_target = "aiagents4pharma.talk2scholars.agents.main_agent.ChatOpenAI"

    with patch(chat_openai_target, return_value=llm_stub):
        app = get_app(thread_id, llm_model)

        assert app is not None
        for node_name in ("supervisor", "s2_agent"):
            assert node_name in app.nodes
42
+
43
+
44
def test_supervisor_node_execution():
    """Run the supervisor node once and check its routing decision.

    ``create_react_agent`` is patched to return a mock whose ``invoke``
    replies with an ``AIMessage`` containing ``"s2_agent"``. The node's
    result must route to ``"s2_agent"`` and ``invoke`` must have been called
    exactly once with the state and the thread-scoped config.
    """
    thread_id = "test_thread"
    llm_stub = Mock()

    # Supervisor stand-in that always answers with the S2 agent's name.
    supervisor_stub = Mock()
    supervisor_reply = {"messages": [AIMessage(content="s2_agent")]}
    supervisor_stub.invoke.return_value = supervisor_reply

    with patch(
        "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent",
        return_value=supervisor_stub,
    ):
        supervisor_node = make_supervisor_node(llm_stub, thread_id)

        state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
        outcome = supervisor_node(state)

        assert outcome.goto == "s2_agent"
        expected_config = {"configurable": {"thread_id": thread_id}}
        supervisor_stub.invoke.assert_called_once_with(state, expected_config)
70
+
71
+
72
def test_call_s2_agent():
    """Invoke the main app with the S2 app factory patched out.

    ``s2_agent.get_app`` is patched so that the app it returns answers
    ``invoke`` with a fixed message and paper. The main app's result must
    carry that paper, and the patched ``invoke`` must have run exactly once.
    """
    thread_id = "test_thread"
    state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])

    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with patch("aiagents4pharma.talk2scholars.agents.s2_agent.get_app") as s2_factory:
        s2_invoke = s2_factory.return_value.invoke
        s2_invoke.return_value = {
            "messages": [AIMessage(content="Here are the papers")],
            "papers": {"id123": "Sample Paper"},
        }

        app = get_app(thread_id)
        result = app.invoke(state, config=run_config)

        assert "messages" in result
        assert "papers" in result
        assert result["papers"]["id123"] == "Sample Paper"

        s2_invoke.assert_called_once()
100
+
101
+
102
def test_hydra_failure():
    """Check that a Hydra initialization error propagates out of ``get_app``.

    ``hydra.initialize`` is patched to raise ``Exception("Hydra error")``;
    building the app must re-raise an exception whose message contains that
    text.
    """
    thread_id = "test_thread"
    with patch("hydra.initialize", side_effect=Exception("Hydra error")):
        # ``match`` runs ``re.search`` against ``str(exception)``, so the
        # message check lives in the same statement as the raise check.
        with pytest.raises(Exception, match="Hydra error"):
            get_app(thread_id)
110
+
111
+
112
class TestMainAgent:
    """Basic tests for the main agent initialization and configuration.

    Each test patches ``create_react_agent`` in the main-agent module and
    builds a supervisor node with a ``Mock`` LLM. Tests that take
    ``mock_hydra`` receive the mock standing in for ``hydra.compose``.
    """

    def test_supervisor_node_creation(self, mock_hydra):
        """Check a supervisor node is built and Hydra compose runs once."""
        mock_llm = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ) as mock_create:
            mock_create.return_value = Mock()
            supervisor = make_supervisor_node(mock_llm, thread_id)

            assert supervisor is not None
            assert mock_create.called
            # Only the number of compose calls is checked here; the
            # arguments are inspected in test_supervisor_config_loading.
            assert mock_hydra.call_count == 1

    def test_supervisor_config_loading(self, mock_hydra):
        """Check Hydra compose is called once with the main-agent override."""
        mock_llm = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ):
            make_supervisor_node(mock_llm, thread_id)

            assert mock_hydra.call_count == 1
            # The override string must appear in the first compose call.
            assert "agents/talk2scholars/main_agent=default" in str(
                mock_hydra.call_args_list[0]
            )

    def test_react_agent_params(self):
        """Check the arguments passed to ``create_react_agent``."""
        mock_llm = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ) as mock_create:
            mock_create.return_value = Mock()
            make_supervisor_node(mock_llm, thread_id)

            assert mock_create.called

            args, kwargs = mock_create.call_args
            assert args[0] == mock_llm  # the LLM is the first positional arg
            assert "state_schema" in kwargs
            # The former ``hasattr(mock_create.return_value, "invoke")``
            # check was dropped: a Mock fabricates any attribute on access,
            # so that assertion could never fail.

    def test_supervisor_custom_config(self, mock_hydra):
        """Check Hydra compose is called exactly once for a supervisor node.

        NOTE(review): despite the name, no custom configuration is supplied
        here; the test currently overlaps test_supervisor_config_loading.
        """
        mock_llm = Mock()
        thread_id = "test_thread"

        with patch(
            "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
        ):
            make_supervisor_node(mock_llm, thread_id)

            mock_hydra.assert_called_once()
@@ -0,0 +1,138 @@
1
+ """
2
+ Unit tests for the S2 agent (Semantic Scholar sub-agent).
3
+ """
4
+
5
+ from unittest import mock
6
+ import pytest
7
+ from langchain_core.messages import HumanMessage, AIMessage
8
+ from ..agents.s2_agent import get_app
9
+ from ..state.state_talk2scholars import Talk2Scholars
10
+
11
+
12
@pytest.fixture(autouse=True)
def mock_hydra_fixture():
    """Patch Hydra so S2-agent tests receive a canned configuration.

    ``hydra.initialize`` is replaced by a mock and ``hydra.compose`` returns
    a ``MagicMock`` whose S2-agent temperature and prompt are preset.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    with mock.patch("hydra.initialize"), mock.patch("hydra.compose") as compose_stub:
        config = mock.MagicMock()
        # Child mocks are cached per attribute, so this alias configures the
        # same object reachable from ``config``.
        s2_cfg = config.agents.talk2scholars.s2_agent
        s2_cfg.temperature = 0
        s2_cfg.s2_agent = "Test prompt"
        compose_stub.return_value = config
        yield compose_stub
21
+
22
+
23
@pytest.fixture
def mock_tools_fixture():
    """Patch the four S2 tool entry points with mocks.

    Each patched tool returns a fresh ``Mock`` when called.

    Yields:
        A list of the four tool mocks in the order: search, display
        results, single-paper recommendations, multi-paper recommendations.
    """
    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.search.search_tool"
        ) as search_stub,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.display_results.display_results"
        ) as display_stub,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.single_paper_rec."
            "get_single_paper_recommendations"
        ) as single_rec_stub,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
            "get_multi_paper_recommendations"
        ) as multi_rec_stub,
    ):
        tool_stubs = [search_stub, display_stub, single_rec_stub, multi_rec_stub]
        # Give every tool its own distinct Mock as a return value.
        for stub in tool_stubs:
            stub.return_value = mock.Mock()

        yield tool_stubs
49
+
50
+
51
@pytest.mark.usefixtures("mock_hydra_fixture")
def test_s2_agent_initialization():
    """Build the S2 app with ``create_react_agent`` patched out.

    The app must be created and the patched factory must have been called.
    """
    thread_id = "test_thread"
    react_agent_target = (
        "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
    )

    with mock.patch(react_agent_target) as react_agent_factory:
        react_agent_factory.return_value = mock.Mock()

        app = get_app(thread_id)

        assert app is not None
        assert react_agent_factory.called
65
+
66
+
67
def test_s2_agent_invocation():
    """Invoke the S2 app with a mocked react agent and check the result.

    The react agent created inside ``get_app`` is replaced by a mock whose
    ``invoke`` returns a fixed message and paper; the app's result must
    contain ``"messages"`` and that same paper under ``"papers"``.
    """
    thread_id = "test_thread"
    state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])

    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": "test_ns",
            "checkpoint_id": "test_checkpoint",
        }
    }

    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
    ) as react_agent_factory:
        agent_stub = mock.Mock()
        agent_stub.invoke.return_value = {
            "messages": [AIMessage(content="Here are some AI papers")],
            "papers": {"id123": "AI Research Paper"},
        }
        react_agent_factory.return_value = agent_stub

        app = get_app(thread_id)
        result = app.invoke(state, config=run_config)

        assert "messages" in result
        assert "papers" in result
        assert result["papers"]["id123"] == "AI Research Paper"
97
+
98
+
99
def test_s2_agent_tools_assignment(request):
    """Build the S2 app with ``ToolNode`` patched and check it was used.

    The tool mocks are fetched through ``request.getfixturevalue`` rather
    than as a parameter, so the fixture name is not redefined in the
    function signature.
    """
    thread_id = "test_thread"

    tool_stubs = request.getfixturevalue("mock_tools_fixture")

    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
        ) as react_agent_factory,
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.s2_agent.ToolNode"
        ) as toolnode_cls,
    ):
        react_agent_factory.return_value = mock.Mock()

        # ToolNode stand-in that exposes the fixture's tool mocks.
        toolnode_stub = mock.Mock()
        toolnode_stub.tools = tool_stubs
        toolnode_cls.return_value = toolnode_stub

        get_app(thread_id)

        assert toolnode_cls.called
        # NOTE(review): this length is that of the list assigned a few lines
        # above, not of anything the agent passed to ToolNode -- confirm
        # whether the call arguments should be inspected instead.
        assert len(toolnode_stub.tools) == 4
128
+
129
+
130
def test_s2_agent_hydra_failure():
    """Check that a Hydra initialization error propagates out of ``get_app``.

    ``hydra.initialize`` is patched to raise ``Exception("Hydra error")``;
    building the S2 app must re-raise an exception whose message contains
    that text.
    """
    thread_id = "test_thread"

    with mock.patch("hydra.initialize", side_effect=Exception("Hydra error")):
        # ``match`` runs ``re.search`` against ``str(exception)``, so the
        # message check lives in the same statement as the raise check.
        with pytest.raises(Exception, match="Hydra error"):
            get_app(thread_id)