aiagents4pharma 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2biomodels/configs/config.yaml +5 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +90 -91
- aiagents4pharma/talk2scholars/agents/s2_agent.py +61 -17
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +31 -10
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -16
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +11 -9
- aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -0
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +36 -7
- aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +58 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +98 -122
- aiagents4pharma/talk2scholars/tests/test_s2_agent.py +95 -29
- aiagents4pharma/talk2scholars/tests/test_s2_tools.py +158 -22
- aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -2
- aiagents4pharma/talk2scholars/tools/s2/display_results.py +60 -21
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +35 -8
- aiagents4pharma/talk2scholars/tools/s2/query_results.py +61 -0
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +79 -0
- aiagents4pharma/talk2scholars/tools/s2/search.py +34 -10
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +39 -9
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/METADATA +2 -2
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/RECORD +28 -24
- aiagents4pharma/talk2scholars/tests/test_integration.py +0 -237
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,58 @@
|
|
1
|
+
"""
|
2
|
+
Integration tests for talk2scholars system with OpenAI.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import pytest
|
7
|
+
import hydra
|
8
|
+
from langchain_openai import ChatOpenAI
|
9
|
+
from langchain_core.messages import HumanMessage, AIMessage
|
10
|
+
from ..agents.main_agent import get_app
|
11
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
12
|
+
|
13
|
+
# pylint: disable=redefined-outer-name
|
14
|
+
|
15
|
+
|
16
|
+
@pytest.mark.skipif(
|
17
|
+
not os.getenv("OPENAI_API_KEY"), reason="Requires OpenAI API key to run"
|
18
|
+
)
|
19
|
+
def test_main_agent_real_llm():
|
20
|
+
"""
|
21
|
+
Test that the main agent invokes S2 agent correctly
|
22
|
+
and updates the state with real LLM execution.
|
23
|
+
"""
|
24
|
+
|
25
|
+
# Load Hydra Configuration EXACTLY like in main_agent.py
|
26
|
+
with hydra.initialize(version_base=None, config_path="../configs"):
|
27
|
+
cfg = hydra.compose(
|
28
|
+
config_name="config", overrides=["agents/talk2scholars/main_agent=default"]
|
29
|
+
)
|
30
|
+
hydra_cfg = cfg.agents.talk2scholars.main_agent
|
31
|
+
|
32
|
+
assert hydra_cfg is not None, "Hydra config failed to load"
|
33
|
+
|
34
|
+
# Use the real OpenAI API (ensure env variable is set)
|
35
|
+
llm = ChatOpenAI(model="gpt-4o-mini", temperature=hydra_cfg.temperature)
|
36
|
+
|
37
|
+
# Initialize main agent workflow (WITH real Hydra config)
|
38
|
+
thread_id = "test_thread"
|
39
|
+
app = get_app(thread_id, llm)
|
40
|
+
|
41
|
+
# Provide an actual user query
|
42
|
+
initial_state = Talk2Scholars(
|
43
|
+
messages=[HumanMessage(content="Find AI papers on transformers")]
|
44
|
+
)
|
45
|
+
|
46
|
+
# Invoke the agent (triggers supervisor → s2_agent)
|
47
|
+
result = app.invoke(
|
48
|
+
initial_state,
|
49
|
+
{"configurable": {"config_id": thread_id, "thread_id": thread_id}},
|
50
|
+
)
|
51
|
+
|
52
|
+
# Assert that the supervisor routed correctly
|
53
|
+
assert "messages" in result, "Expected messages in response"
|
54
|
+
|
55
|
+
# Fix: Accept AIMessage as a valid response type
|
56
|
+
assert isinstance(
|
57
|
+
result["messages"][-1], (HumanMessage, AIMessage, str)
|
58
|
+
), "Last message should be a valid response"
|
@@ -4,10 +4,12 @@ Tests the supervisor agent's routing logic and state management.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
# pylint: disable=redefined-outer-name
|
7
|
+
# pylint: disable=redefined-outer-name,too-few-public-methods
|
7
8
|
from unittest.mock import Mock, patch, MagicMock
|
8
9
|
import pytest
|
9
10
|
from langchain_core.messages import HumanMessage, AIMessage
|
10
|
-
from
|
11
|
+
from langgraph.graph import END
|
12
|
+
from ..agents.main_agent import make_supervisor_node, get_hydra_config, get_app
|
11
13
|
from ..state.state_talk2scholars import Talk2Scholars
|
12
14
|
|
13
15
|
|
@@ -17,7 +19,8 @@ def mock_hydra():
|
|
17
19
|
with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
|
18
20
|
cfg_mock = MagicMock()
|
19
21
|
cfg_mock.agents.talk2scholars.main_agent.temperature = 0
|
20
|
-
cfg_mock.agents.talk2scholars.main_agent.
|
22
|
+
cfg_mock.agents.talk2scholars.main_agent.system_prompt = "System prompt"
|
23
|
+
cfg_mock.agents.talk2scholars.main_agent.router_prompt = "Router prompt"
|
21
24
|
mock_compose.return_value = cfg_mock
|
22
25
|
yield mock_compose
|
23
26
|
|
@@ -25,156 +28,129 @@ def mock_hydra():
|
|
25
28
|
def test_get_app():
|
26
29
|
"""Test the full initialization of the LangGraph application."""
|
27
30
|
thread_id = "test_thread"
|
28
|
-
llm_model = "gpt-4o-mini"
|
29
|
-
|
30
|
-
# Mock the LLM
|
31
31
|
mock_llm = Mock()
|
32
|
+
app = get_app(thread_id, mock_llm)
|
33
|
+
assert app is not None
|
34
|
+
assert "supervisor" in app.nodes
|
35
|
+
assert "s2_agent" in app.nodes # Ensure nodes exist
|
32
36
|
|
33
|
-
with patch(
|
34
|
-
"aiagents4pharma.talk2scholars.agents.main_agent.ChatOpenAI",
|
35
|
-
return_value=mock_llm,
|
36
|
-
):
|
37
37
|
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
38
|
+
def test_get_app_with_default_llm():
|
39
|
+
"""Test app initialization with default LLM parameters."""
|
40
|
+
thread_id = "test_thread"
|
41
|
+
llm_mock = Mock()
|
42
|
+
|
43
|
+
# We need to explicitly pass the mock instead of patching, since the function uses
|
44
|
+
# ChatOpenAI as a default argument value which is evaluated at function definition time
|
45
|
+
app = get_app(thread_id, llm_mock)
|
46
|
+
assert app is not None
|
47
|
+
# We can only verify the app was created successfully
|
48
|
+
|
49
|
+
|
50
|
+
def test_get_hydra_config():
|
51
|
+
"""Test that Hydra configuration loads correctly."""
|
52
|
+
with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
|
53
|
+
cfg_mock = MagicMock()
|
54
|
+
cfg_mock.agents.talk2scholars.main_agent.temperature = 0
|
55
|
+
mock_compose.return_value = cfg_mock
|
56
|
+
cfg = get_hydra_config()
|
57
|
+
assert cfg is not None
|
58
|
+
assert cfg.temperature == 0
|
59
|
+
|
60
|
+
|
61
|
+
def test_hydra_failure():
|
62
|
+
"""Test exception handling when Hydra fails to load config."""
|
63
|
+
thread_id = "test_thread"
|
64
|
+
with patch("hydra.initialize", side_effect=Exception("Hydra error")):
|
65
|
+
with pytest.raises(Exception) as exc_info:
|
66
|
+
get_app(thread_id)
|
67
|
+
assert "Hydra error" in str(exc_info.value)
|
42
68
|
|
43
69
|
|
44
70
|
def test_supervisor_node_execution():
|
45
|
-
"""Test that the supervisor node
|
71
|
+
"""Test that the supervisor node routes correctly."""
|
46
72
|
mock_llm = Mock()
|
47
73
|
thread_id = "test_thread"
|
48
74
|
|
49
|
-
|
50
|
-
|
51
|
-
mock_supervisor.invoke.return_value = {"messages": [AIMessage(content="s2_agent")]}
|
75
|
+
class MockRouter:
|
76
|
+
"""Mock router class."""
|
52
77
|
|
53
|
-
|
54
|
-
|
55
|
-
|
78
|
+
next = "s2_agent"
|
79
|
+
|
80
|
+
with (
|
81
|
+
patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
|
82
|
+
patch.object(mock_llm, "invoke", return_value=MockRouter()),
|
56
83
|
):
|
57
84
|
supervisor_node = make_supervisor_node(mock_llm, thread_id)
|
58
|
-
|
59
|
-
# Create a mock state
|
60
85
|
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
61
|
-
|
62
|
-
# Execute
|
63
86
|
result = supervisor_node(mock_state)
|
64
|
-
|
65
|
-
# Validate
|
66
87
|
assert result.goto == "s2_agent"
|
67
|
-
mock_supervisor.invoke.assert_called_once_with(
|
68
|
-
mock_state, {"configurable": {"thread_id": thread_id}}
|
69
|
-
) # Ensure invoke was called correctly
|
70
88
|
|
71
89
|
|
72
|
-
def
|
73
|
-
"""Test
|
90
|
+
def test_supervisor_node_finish():
|
91
|
+
"""Test that supervisor node correctly handles FINISH case."""
|
92
|
+
mock_llm = Mock()
|
74
93
|
thread_id = "test_thread"
|
75
|
-
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
76
94
|
|
77
|
-
|
78
|
-
|
79
|
-
"messages": [AIMessage(content="Here are the papers")],
|
80
|
-
"papers": {"id123": "Sample Paper"},
|
81
|
-
}
|
95
|
+
class MockRouter:
|
96
|
+
"""Mock router class."""
|
82
97
|
|
83
|
-
|
84
|
-
result = app.invoke(
|
85
|
-
mock_state,
|
86
|
-
config={
|
87
|
-
"configurable": {
|
88
|
-
"thread_id": thread_id,
|
89
|
-
"checkpoint_ns": "test_ns",
|
90
|
-
"checkpoint_id": "test_checkpoint",
|
91
|
-
}
|
92
|
-
},
|
93
|
-
)
|
98
|
+
next = "FINISH"
|
94
99
|
|
95
|
-
|
96
|
-
|
97
|
-
assert result["papers"]["id123"] == "Sample Paper"
|
100
|
+
class MockAIResponse:
|
101
|
+
"""Mock AI response class."""
|
98
102
|
|
99
|
-
|
103
|
+
def __init__(self):
|
104
|
+
self.content = "Final AI Response"
|
100
105
|
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
assert "
|
106
|
+
with (
|
107
|
+
patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
|
108
|
+
patch.object(mock_llm, "invoke", side_effect=[MockRouter(), MockAIResponse()]),
|
109
|
+
):
|
110
|
+
supervisor_node = make_supervisor_node(mock_llm, thread_id)
|
111
|
+
mock_state = Talk2Scholars(messages=[HumanMessage(content="End conversation")])
|
112
|
+
result = supervisor_node(mock_state)
|
113
|
+
assert result.goto == END
|
114
|
+
assert "messages" in result.update
|
115
|
+
assert isinstance(result.update["messages"], AIMessage)
|
116
|
+
assert result.update["messages"].content == "Final AI Response"
|
110
117
|
|
111
118
|
|
112
|
-
|
113
|
-
"""
|
119
|
+
def test_call_s2_agent_failure_in_get_app():
|
120
|
+
"""Test handling failure when calling s2_agent.get_app()."""
|
121
|
+
thread_id = "test_thread"
|
122
|
+
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
114
123
|
|
115
|
-
|
116
|
-
""
|
117
|
-
|
118
|
-
|
124
|
+
with patch(
|
125
|
+
"aiagents4pharma.talk2scholars.agents.s2_agent.get_app",
|
126
|
+
side_effect=Exception("S2 Agent Failure"),
|
127
|
+
):
|
128
|
+
with pytest.raises(Exception) as exc_info:
|
129
|
+
app = get_app(thread_id) # Get the compiled workflow
|
130
|
+
app.invoke(
|
131
|
+
mock_state,
|
132
|
+
{"configurable": {"config_id": thread_id, "thread_id": thread_id}},
|
133
|
+
)
|
119
134
|
|
120
|
-
|
121
|
-
"aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
|
122
|
-
) as mock_create:
|
123
|
-
mock_create.return_value = Mock()
|
124
|
-
supervisor = make_supervisor_node(mock_llm, thread_id)
|
135
|
+
assert "S2 Agent Failure" in str(exc_info.value)
|
125
136
|
|
126
|
-
assert supervisor is not None
|
127
|
-
assert mock_create.called
|
128
|
-
# Verify Hydra was called with correct parameters
|
129
|
-
assert mock_hydra.call_count == 1 # Updated assertion
|
130
137
|
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
138
|
+
def test_call_s2_agent_failure_in_invoke():
|
139
|
+
"""Test handling failure when invoking s2_agent app."""
|
140
|
+
thread_id = "test_thread"
|
141
|
+
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
135
142
|
|
136
|
-
|
137
|
-
|
138
|
-
):
|
139
|
-
make_supervisor_node(mock_llm, thread_id)
|
143
|
+
mock_app = Mock()
|
144
|
+
mock_app.invoke.side_effect = Exception("S2 Agent Invoke Failure")
|
140
145
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
146
|
+
with patch(
|
147
|
+
"aiagents4pharma.talk2scholars.agents.s2_agent.get_app", return_value=mock_app
|
148
|
+
):
|
149
|
+
with pytest.raises(Exception) as exc_info:
|
150
|
+
app = get_app(thread_id) # Get the compiled workflow
|
151
|
+
app.invoke(
|
152
|
+
mock_state,
|
153
|
+
{"configurable": {"config_id": thread_id, "thread_id": thread_id}},
|
145
154
|
)
|
146
155
|
|
147
|
-
|
148
|
-
"""Test that react agent is created with correct parameters"""
|
149
|
-
mock_llm = Mock()
|
150
|
-
thread_id = "test_thread"
|
151
|
-
|
152
|
-
with patch(
|
153
|
-
"aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
|
154
|
-
) as mock_create:
|
155
|
-
mock_create.return_value = Mock()
|
156
|
-
make_supervisor_node(mock_llm, thread_id)
|
157
|
-
|
158
|
-
# Verify create_react_agent was called
|
159
|
-
assert mock_create.called
|
160
|
-
|
161
|
-
# Verify the parameters
|
162
|
-
args, kwargs = mock_create.call_args
|
163
|
-
assert args[0] == mock_llm # First argument should be the LLM
|
164
|
-
assert "state_schema" in kwargs # Should have state_schema
|
165
|
-
assert hasattr(
|
166
|
-
mock_create.return_value, "invoke"
|
167
|
-
) # Should have invoke method
|
168
|
-
|
169
|
-
def test_supervisor_custom_config(self, mock_hydra):
|
170
|
-
"""Test supervisor with custom configuration"""
|
171
|
-
mock_llm = Mock()
|
172
|
-
thread_id = "test_thread"
|
173
|
-
|
174
|
-
with patch(
|
175
|
-
"aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
|
176
|
-
):
|
177
|
-
make_supervisor_node(mock_llm, thread_id)
|
178
|
-
|
179
|
-
# Verify Hydra was called
|
180
|
-
mock_hydra.assert_called_once()
|
156
|
+
assert "S2 Agent Invoke Failure" in str(exc_info.value)
|
@@ -1,7 +1,8 @@
|
|
1
1
|
"""
|
2
|
-
Unit
|
2
|
+
Updated Unit Tests for the S2 agent (Semantic Scholar sub-agent).
|
3
3
|
"""
|
4
4
|
|
5
|
+
# pylint: disable=redefined-outer-name
|
5
6
|
from unittest import mock
|
6
7
|
import pytest
|
7
8
|
from langchain_core.messages import HumanMessage, AIMessage
|
@@ -35,31 +36,42 @@ def mock_tools_fixture():
|
|
35
36
|
"get_single_paper_recommendations"
|
36
37
|
) as mock_s2_single_rec,
|
37
38
|
mock.patch(
|
38
|
-
"aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
|
39
|
-
"get_multi_paper_recommendations"
|
39
|
+
"aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec.get_multi_paper_recommendations"
|
40
40
|
) as mock_s2_multi_rec,
|
41
|
+
mock.patch(
|
42
|
+
"aiagents4pharma.talk2scholars.tools.s2.query_results.query_results"
|
43
|
+
) as mock_s2_query_results,
|
44
|
+
mock.patch(
|
45
|
+
"aiagents4pharma.talk2scholars.tools.s2.retrieve_semantic_scholar_paper_id."
|
46
|
+
"retrieve_semantic_scholar_paper_id"
|
47
|
+
) as mock_s2_retrieve_id,
|
41
48
|
):
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
+
mock_s2_search.return_value = {"result": "Mock Search Result"}
|
50
|
+
mock_s2_display.return_value = {"result": "Mock Display Result"}
|
51
|
+
mock_s2_single_rec.return_value = {"result": "Mock Single Rec"}
|
52
|
+
mock_s2_multi_rec.return_value = {"result": "Mock Multi Rec"}
|
53
|
+
mock_s2_query_results.return_value = {"result": "Mock Query Result"}
|
54
|
+
mock_s2_retrieve_id.return_value = {"paper_id": "MockPaper123"}
|
55
|
+
|
56
|
+
yield [
|
57
|
+
mock_s2_search,
|
58
|
+
mock_s2_display,
|
59
|
+
mock_s2_single_rec,
|
60
|
+
mock_s2_multi_rec,
|
61
|
+
mock_s2_query_results,
|
62
|
+
mock_s2_retrieve_id,
|
63
|
+
]
|
49
64
|
|
50
65
|
|
51
66
|
@pytest.mark.usefixtures("mock_hydra_fixture")
|
52
67
|
def test_s2_agent_initialization():
|
53
68
|
"""Test that S2 agent initializes correctly with mock configuration."""
|
54
69
|
thread_id = "test_thread"
|
55
|
-
|
56
70
|
with mock.patch(
|
57
71
|
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
58
72
|
) as mock_create:
|
59
73
|
mock_create.return_value = mock.Mock()
|
60
|
-
|
61
74
|
app = get_app(thread_id)
|
62
|
-
|
63
75
|
assert app is not None
|
64
76
|
assert mock_create.called
|
65
77
|
|
@@ -68,7 +80,6 @@ def test_s2_agent_invocation():
|
|
68
80
|
"""Test that the S2 agent processes user input and returns a valid response."""
|
69
81
|
thread_id = "test_thread"
|
70
82
|
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
71
|
-
|
72
83
|
with mock.patch(
|
73
84
|
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
74
85
|
) as mock_create:
|
@@ -78,7 +89,6 @@ def test_s2_agent_invocation():
|
|
78
89
|
"messages": [AIMessage(content="Here are some AI papers")],
|
79
90
|
"papers": {"id123": "AI Research Paper"},
|
80
91
|
}
|
81
|
-
|
82
92
|
app = get_app(thread_id)
|
83
93
|
result = app.invoke(
|
84
94
|
mock_state,
|
@@ -90,7 +100,6 @@ def test_s2_agent_invocation():
|
|
90
100
|
}
|
91
101
|
},
|
92
102
|
)
|
93
|
-
|
94
103
|
assert "messages" in result
|
95
104
|
assert "papers" in result
|
96
105
|
assert result["papers"]["id123"] == "AI Research Paper"
|
@@ -99,10 +108,7 @@ def test_s2_agent_invocation():
|
|
99
108
|
def test_s2_agent_tools_assignment(request):
|
100
109
|
"""Ensure that the correct tools are assigned to the agent."""
|
101
110
|
thread_id = "test_thread"
|
102
|
-
|
103
|
-
# Dynamically retrieve the fixture to avoid redefining it in the function signature
|
104
111
|
mock_tools = request.getfixturevalue("mock_tools_fixture")
|
105
|
-
|
106
112
|
with (
|
107
113
|
mock.patch(
|
108
114
|
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
@@ -111,28 +117,88 @@ def test_s2_agent_tools_assignment(request):
|
|
111
117
|
"aiagents4pharma.talk2scholars.agents.s2_agent.ToolNode"
|
112
118
|
) as mock_toolnode,
|
113
119
|
):
|
114
|
-
|
115
120
|
mock_agent = mock.Mock()
|
116
121
|
mock_create.return_value = mock_agent
|
117
|
-
|
118
|
-
# Mock ToolNode behavior
|
119
122
|
mock_tool_instance = mock.Mock()
|
120
|
-
mock_tool_instance.tools = mock_tools
|
123
|
+
mock_tool_instance.tools = mock_tools
|
121
124
|
mock_toolnode.return_value = mock_tool_instance
|
122
|
-
|
123
125
|
get_app(thread_id)
|
124
|
-
|
125
|
-
# Ensure the agent was created with the mocked ToolNode
|
126
126
|
assert mock_toolnode.called
|
127
|
-
assert len(mock_tool_instance.tools) ==
|
127
|
+
assert len(mock_tool_instance.tools) == 6
|
128
|
+
|
129
|
+
|
130
|
+
def test_s2_query_results_tool():
|
131
|
+
"""Test if the query_results tool is correctly utilized by the agent."""
|
132
|
+
thread_id = "test_thread"
|
133
|
+
mock_state = Talk2Scholars(
|
134
|
+
messages=[HumanMessage(content="Query results for AI papers")]
|
135
|
+
)
|
136
|
+
with mock.patch(
|
137
|
+
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
138
|
+
) as mock_create:
|
139
|
+
mock_agent = mock.Mock()
|
140
|
+
mock_create.return_value = mock_agent
|
141
|
+
mock_agent.invoke.return_value = {
|
142
|
+
"messages": [HumanMessage(content="Query results for AI papers")],
|
143
|
+
"last_displayed_papers": {},
|
144
|
+
"papers": {
|
145
|
+
"query_results": "Mock Query Result"
|
146
|
+
}, # Ensure the expected key is inside 'papers'
|
147
|
+
"multi_papers": {},
|
148
|
+
}
|
149
|
+
app = get_app(thread_id)
|
150
|
+
result = app.invoke(
|
151
|
+
mock_state,
|
152
|
+
config={
|
153
|
+
"configurable": {
|
154
|
+
"thread_id": thread_id,
|
155
|
+
"checkpoint_ns": "test_ns",
|
156
|
+
"checkpoint_id": "test_checkpoint",
|
157
|
+
}
|
158
|
+
},
|
159
|
+
)
|
160
|
+
assert "query_results" in result["papers"]
|
161
|
+
assert mock_agent.invoke.called
|
162
|
+
|
163
|
+
|
164
|
+
def test_s2_retrieve_id_tool():
|
165
|
+
"""Test if the retrieve_semantic_scholar_paper_id tool is correctly utilized by the agent."""
|
166
|
+
thread_id = "test_thread"
|
167
|
+
mock_state = Talk2Scholars(
|
168
|
+
messages=[HumanMessage(content="Retrieve paper ID for AI research")]
|
169
|
+
)
|
170
|
+
with mock.patch(
|
171
|
+
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
172
|
+
) as mock_create:
|
173
|
+
mock_agent = mock.Mock()
|
174
|
+
mock_create.return_value = mock_agent
|
175
|
+
mock_agent.invoke.return_value = {
|
176
|
+
"messages": [HumanMessage(content="Retrieve paper ID for AI research")],
|
177
|
+
"last_displayed_papers": {},
|
178
|
+
"papers": {
|
179
|
+
"paper_id": "MockPaper123"
|
180
|
+
}, # Ensure 'paper_id' is inside 'papers'
|
181
|
+
"multi_papers": {},
|
182
|
+
}
|
183
|
+
app = get_app(thread_id)
|
184
|
+
result = app.invoke(
|
185
|
+
mock_state,
|
186
|
+
config={
|
187
|
+
"configurable": {
|
188
|
+
"thread_id": thread_id,
|
189
|
+
"checkpoint_ns": "test_ns",
|
190
|
+
"checkpoint_id": "test_checkpoint",
|
191
|
+
}
|
192
|
+
},
|
193
|
+
)
|
194
|
+
assert "paper_id" in result["papers"]
|
195
|
+
assert mock_agent.invoke.called
|
128
196
|
|
129
197
|
|
130
198
|
def test_s2_agent_hydra_failure():
|
131
199
|
"""Test exception handling when Hydra fails to load config."""
|
132
200
|
thread_id = "test_thread"
|
133
|
-
|
134
201
|
with mock.patch("hydra.initialize", side_effect=Exception("Hydra error")):
|
135
202
|
with pytest.raises(Exception) as exc_info:
|
136
203
|
get_app(thread_id)
|
137
|
-
|
138
204
|
assert "Hydra error" in str(exc_info.value)
|