aiagents4pharma 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- aiagents4pharma/talk2biomodels/configs/config.yaml +5 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +90 -91
- aiagents4pharma/talk2scholars/agents/s2_agent.py +61 -17
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +31 -10
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -16
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +11 -9
- aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -0
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +36 -7
- aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +58 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +98 -122
- aiagents4pharma/talk2scholars/tests/test_s2_agent.py +95 -29
- aiagents4pharma/talk2scholars/tests/test_s2_tools.py +158 -22
- aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -2
- aiagents4pharma/talk2scholars/tools/s2/display_results.py +60 -21
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +35 -8
- aiagents4pharma/talk2scholars/tools/s2/query_results.py +61 -0
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +79 -0
- aiagents4pharma/talk2scholars/tools/s2/search.py +34 -10
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +39 -9
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/METADATA +2 -2
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/RECORD +28 -24
- aiagents4pharma/talk2scholars/tests/test_integration.py +0 -237
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,58 @@
|
|
1
|
+
"""
|
2
|
+
Integration tests for talk2scholars system with OpenAI.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import pytest
|
7
|
+
import hydra
|
8
|
+
from langchain_openai import ChatOpenAI
|
9
|
+
from langchain_core.messages import HumanMessage, AIMessage
|
10
|
+
from ..agents.main_agent import get_app
|
11
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
12
|
+
|
13
|
+
# pylint: disable=redefined-outer-name
|
14
|
+
|
15
|
+
|
16
|
+
@pytest.mark.skipif(
|
17
|
+
not os.getenv("OPENAI_API_KEY"), reason="Requires OpenAI API key to run"
|
18
|
+
)
|
19
|
+
def test_main_agent_real_llm():
|
20
|
+
"""
|
21
|
+
Test that the main agent invokes S2 agent correctly
|
22
|
+
and updates the state with real LLM execution.
|
23
|
+
"""
|
24
|
+
|
25
|
+
# Load Hydra Configuration EXACTLY like in main_agent.py
|
26
|
+
with hydra.initialize(version_base=None, config_path="../configs"):
|
27
|
+
cfg = hydra.compose(
|
28
|
+
config_name="config", overrides=["agents/talk2scholars/main_agent=default"]
|
29
|
+
)
|
30
|
+
hydra_cfg = cfg.agents.talk2scholars.main_agent
|
31
|
+
|
32
|
+
assert hydra_cfg is not None, "Hydra config failed to load"
|
33
|
+
|
34
|
+
# Use the real OpenAI API (ensure env variable is set)
|
35
|
+
llm = ChatOpenAI(model="gpt-4o-mini", temperature=hydra_cfg.temperature)
|
36
|
+
|
37
|
+
# Initialize main agent workflow (WITH real Hydra config)
|
38
|
+
thread_id = "test_thread"
|
39
|
+
app = get_app(thread_id, llm)
|
40
|
+
|
41
|
+
# Provide an actual user query
|
42
|
+
initial_state = Talk2Scholars(
|
43
|
+
messages=[HumanMessage(content="Find AI papers on transformers")]
|
44
|
+
)
|
45
|
+
|
46
|
+
# Invoke the agent (triggers supervisor → s2_agent)
|
47
|
+
result = app.invoke(
|
48
|
+
initial_state,
|
49
|
+
{"configurable": {"config_id": thread_id, "thread_id": thread_id}},
|
50
|
+
)
|
51
|
+
|
52
|
+
# Assert that the supervisor routed correctly
|
53
|
+
assert "messages" in result, "Expected messages in response"
|
54
|
+
|
55
|
+
# Fix: Accept AIMessage as a valid response type
|
56
|
+
assert isinstance(
|
57
|
+
result["messages"][-1], (HumanMessage, AIMessage, str)
|
58
|
+
), "Last message should be a valid response"
|
@@ -4,10 +4,12 @@ Tests the supervisor agent's routing logic and state management.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
# pylint: disable=redefined-outer-name
|
7
|
+
# pylint: disable=redefined-outer-name,too-few-public-methods
|
7
8
|
from unittest.mock import Mock, patch, MagicMock
|
8
9
|
import pytest
|
9
10
|
from langchain_core.messages import HumanMessage, AIMessage
|
10
|
-
from
|
11
|
+
from langgraph.graph import END
|
12
|
+
from ..agents.main_agent import make_supervisor_node, get_hydra_config, get_app
|
11
13
|
from ..state.state_talk2scholars import Talk2Scholars
|
12
14
|
|
13
15
|
|
@@ -17,7 +19,8 @@ def mock_hydra():
|
|
17
19
|
with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
|
18
20
|
cfg_mock = MagicMock()
|
19
21
|
cfg_mock.agents.talk2scholars.main_agent.temperature = 0
|
20
|
-
cfg_mock.agents.talk2scholars.main_agent.
|
22
|
+
cfg_mock.agents.talk2scholars.main_agent.system_prompt = "System prompt"
|
23
|
+
cfg_mock.agents.talk2scholars.main_agent.router_prompt = "Router prompt"
|
21
24
|
mock_compose.return_value = cfg_mock
|
22
25
|
yield mock_compose
|
23
26
|
|
@@ -25,156 +28,129 @@ def mock_hydra():
|
|
25
28
|
def test_get_app():
|
26
29
|
"""Test the full initialization of the LangGraph application."""
|
27
30
|
thread_id = "test_thread"
|
28
|
-
llm_model = "gpt-4o-mini"
|
29
|
-
|
30
|
-
# Mock the LLM
|
31
31
|
mock_llm = Mock()
|
32
|
+
app = get_app(thread_id, mock_llm)
|
33
|
+
assert app is not None
|
34
|
+
assert "supervisor" in app.nodes
|
35
|
+
assert "s2_agent" in app.nodes # Ensure nodes exist
|
32
36
|
|
33
|
-
with patch(
|
34
|
-
"aiagents4pharma.talk2scholars.agents.main_agent.ChatOpenAI",
|
35
|
-
return_value=mock_llm,
|
36
|
-
):
|
37
37
|
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
38
|
+
def test_get_app_with_default_llm():
|
39
|
+
"""Test app initialization with default LLM parameters."""
|
40
|
+
thread_id = "test_thread"
|
41
|
+
llm_mock = Mock()
|
42
|
+
|
43
|
+
# We need to explicitly pass the mock instead of patching, since the function uses
|
44
|
+
# ChatOpenAI as a default argument value which is evaluated at function definition time
|
45
|
+
app = get_app(thread_id, llm_mock)
|
46
|
+
assert app is not None
|
47
|
+
# We can only verify the app was created successfully
|
48
|
+
|
49
|
+
|
50
|
+
def test_get_hydra_config():
|
51
|
+
"""Test that Hydra configuration loads correctly."""
|
52
|
+
with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
|
53
|
+
cfg_mock = MagicMock()
|
54
|
+
cfg_mock.agents.talk2scholars.main_agent.temperature = 0
|
55
|
+
mock_compose.return_value = cfg_mock
|
56
|
+
cfg = get_hydra_config()
|
57
|
+
assert cfg is not None
|
58
|
+
assert cfg.temperature == 0
|
59
|
+
|
60
|
+
|
61
|
+
def test_hydra_failure():
|
62
|
+
"""Test exception handling when Hydra fails to load config."""
|
63
|
+
thread_id = "test_thread"
|
64
|
+
with patch("hydra.initialize", side_effect=Exception("Hydra error")):
|
65
|
+
with pytest.raises(Exception) as exc_info:
|
66
|
+
get_app(thread_id)
|
67
|
+
assert "Hydra error" in str(exc_info.value)
|
42
68
|
|
43
69
|
|
44
70
|
def test_supervisor_node_execution():
|
45
|
-
"""Test that the supervisor node
|
71
|
+
"""Test that the supervisor node routes correctly."""
|
46
72
|
mock_llm = Mock()
|
47
73
|
thread_id = "test_thread"
|
48
74
|
|
49
|
-
|
50
|
-
|
51
|
-
mock_supervisor.invoke.return_value = {"messages": [AIMessage(content="s2_agent")]}
|
75
|
+
class MockRouter:
|
76
|
+
"""Mock router class."""
|
52
77
|
|
53
|
-
|
54
|
-
|
55
|
-
|
78
|
+
next = "s2_agent"
|
79
|
+
|
80
|
+
with (
|
81
|
+
patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
|
82
|
+
patch.object(mock_llm, "invoke", return_value=MockRouter()),
|
56
83
|
):
|
57
84
|
supervisor_node = make_supervisor_node(mock_llm, thread_id)
|
58
|
-
|
59
|
-
# Create a mock state
|
60
85
|
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
61
|
-
|
62
|
-
# Execute
|
63
86
|
result = supervisor_node(mock_state)
|
64
|
-
|
65
|
-
# Validate
|
66
87
|
assert result.goto == "s2_agent"
|
67
|
-
mock_supervisor.invoke.assert_called_once_with(
|
68
|
-
mock_state, {"configurable": {"thread_id": thread_id}}
|
69
|
-
) # Ensure invoke was called correctly
|
70
88
|
|
71
89
|
|
72
|
-
def
|
73
|
-
"""Test
|
90
|
+
def test_supervisor_node_finish():
|
91
|
+
"""Test that supervisor node correctly handles FINISH case."""
|
92
|
+
mock_llm = Mock()
|
74
93
|
thread_id = "test_thread"
|
75
|
-
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
76
94
|
|
77
|
-
|
78
|
-
|
79
|
-
"messages": [AIMessage(content="Here are the papers")],
|
80
|
-
"papers": {"id123": "Sample Paper"},
|
81
|
-
}
|
95
|
+
class MockRouter:
|
96
|
+
"""Mock router class."""
|
82
97
|
|
83
|
-
|
84
|
-
result = app.invoke(
|
85
|
-
mock_state,
|
86
|
-
config={
|
87
|
-
"configurable": {
|
88
|
-
"thread_id": thread_id,
|
89
|
-
"checkpoint_ns": "test_ns",
|
90
|
-
"checkpoint_id": "test_checkpoint",
|
91
|
-
}
|
92
|
-
},
|
93
|
-
)
|
98
|
+
next = "FINISH"
|
94
99
|
|
95
|
-
|
96
|
-
|
97
|
-
assert result["papers"]["id123"] == "Sample Paper"
|
100
|
+
class MockAIResponse:
|
101
|
+
"""Mock AI response class."""
|
98
102
|
|
99
|
-
|
103
|
+
def __init__(self):
|
104
|
+
self.content = "Final AI Response"
|
100
105
|
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
assert "
|
106
|
+
with (
|
107
|
+
patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
|
108
|
+
patch.object(mock_llm, "invoke", side_effect=[MockRouter(), MockAIResponse()]),
|
109
|
+
):
|
110
|
+
supervisor_node = make_supervisor_node(mock_llm, thread_id)
|
111
|
+
mock_state = Talk2Scholars(messages=[HumanMessage(content="End conversation")])
|
112
|
+
result = supervisor_node(mock_state)
|
113
|
+
assert result.goto == END
|
114
|
+
assert "messages" in result.update
|
115
|
+
assert isinstance(result.update["messages"], AIMessage)
|
116
|
+
assert result.update["messages"].content == "Final AI Response"
|
110
117
|
|
111
118
|
|
112
|
-
|
113
|
-
"""
|
119
|
+
def test_call_s2_agent_failure_in_get_app():
|
120
|
+
"""Test handling failure when calling s2_agent.get_app()."""
|
121
|
+
thread_id = "test_thread"
|
122
|
+
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
114
123
|
|
115
|
-
|
116
|
-
""
|
117
|
-
|
118
|
-
|
124
|
+
with patch(
|
125
|
+
"aiagents4pharma.talk2scholars.agents.s2_agent.get_app",
|
126
|
+
side_effect=Exception("S2 Agent Failure"),
|
127
|
+
):
|
128
|
+
with pytest.raises(Exception) as exc_info:
|
129
|
+
app = get_app(thread_id) # Get the compiled workflow
|
130
|
+
app.invoke(
|
131
|
+
mock_state,
|
132
|
+
{"configurable": {"config_id": thread_id, "thread_id": thread_id}},
|
133
|
+
)
|
119
134
|
|
120
|
-
|
121
|
-
"aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
|
122
|
-
) as mock_create:
|
123
|
-
mock_create.return_value = Mock()
|
124
|
-
supervisor = make_supervisor_node(mock_llm, thread_id)
|
135
|
+
assert "S2 Agent Failure" in str(exc_info.value)
|
125
136
|
|
126
|
-
assert supervisor is not None
|
127
|
-
assert mock_create.called
|
128
|
-
# Verify Hydra was called with correct parameters
|
129
|
-
assert mock_hydra.call_count == 1 # Updated assertion
|
130
137
|
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
138
|
+
def test_call_s2_agent_failure_in_invoke():
|
139
|
+
"""Test handling failure when invoking s2_agent app."""
|
140
|
+
thread_id = "test_thread"
|
141
|
+
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
135
142
|
|
136
|
-
|
137
|
-
|
138
|
-
):
|
139
|
-
make_supervisor_node(mock_llm, thread_id)
|
143
|
+
mock_app = Mock()
|
144
|
+
mock_app.invoke.side_effect = Exception("S2 Agent Invoke Failure")
|
140
145
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
146
|
+
with patch(
|
147
|
+
"aiagents4pharma.talk2scholars.agents.s2_agent.get_app", return_value=mock_app
|
148
|
+
):
|
149
|
+
with pytest.raises(Exception) as exc_info:
|
150
|
+
app = get_app(thread_id) # Get the compiled workflow
|
151
|
+
app.invoke(
|
152
|
+
mock_state,
|
153
|
+
{"configurable": {"config_id": thread_id, "thread_id": thread_id}},
|
145
154
|
)
|
146
155
|
|
147
|
-
|
148
|
-
"""Test that react agent is created with correct parameters"""
|
149
|
-
mock_llm = Mock()
|
150
|
-
thread_id = "test_thread"
|
151
|
-
|
152
|
-
with patch(
|
153
|
-
"aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
|
154
|
-
) as mock_create:
|
155
|
-
mock_create.return_value = Mock()
|
156
|
-
make_supervisor_node(mock_llm, thread_id)
|
157
|
-
|
158
|
-
# Verify create_react_agent was called
|
159
|
-
assert mock_create.called
|
160
|
-
|
161
|
-
# Verify the parameters
|
162
|
-
args, kwargs = mock_create.call_args
|
163
|
-
assert args[0] == mock_llm # First argument should be the LLM
|
164
|
-
assert "state_schema" in kwargs # Should have state_schema
|
165
|
-
assert hasattr(
|
166
|
-
mock_create.return_value, "invoke"
|
167
|
-
) # Should have invoke method
|
168
|
-
|
169
|
-
def test_supervisor_custom_config(self, mock_hydra):
|
170
|
-
"""Test supervisor with custom configuration"""
|
171
|
-
mock_llm = Mock()
|
172
|
-
thread_id = "test_thread"
|
173
|
-
|
174
|
-
with patch(
|
175
|
-
"aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
|
176
|
-
):
|
177
|
-
make_supervisor_node(mock_llm, thread_id)
|
178
|
-
|
179
|
-
# Verify Hydra was called
|
180
|
-
mock_hydra.assert_called_once()
|
156
|
+
assert "S2 Agent Invoke Failure" in str(exc_info.value)
|
@@ -1,7 +1,8 @@
|
|
1
1
|
"""
|
2
|
-
Unit
|
2
|
+
Updated Unit Tests for the S2 agent (Semantic Scholar sub-agent).
|
3
3
|
"""
|
4
4
|
|
5
|
+
# pylint: disable=redefined-outer-name
|
5
6
|
from unittest import mock
|
6
7
|
import pytest
|
7
8
|
from langchain_core.messages import HumanMessage, AIMessage
|
@@ -35,31 +36,42 @@ def mock_tools_fixture():
|
|
35
36
|
"get_single_paper_recommendations"
|
36
37
|
) as mock_s2_single_rec,
|
37
38
|
mock.patch(
|
38
|
-
"aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
|
39
|
-
"get_multi_paper_recommendations"
|
39
|
+
"aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec.get_multi_paper_recommendations"
|
40
40
|
) as mock_s2_multi_rec,
|
41
|
+
mock.patch(
|
42
|
+
"aiagents4pharma.talk2scholars.tools.s2.query_results.query_results"
|
43
|
+
) as mock_s2_query_results,
|
44
|
+
mock.patch(
|
45
|
+
"aiagents4pharma.talk2scholars.tools.s2.retrieve_semantic_scholar_paper_id."
|
46
|
+
"retrieve_semantic_scholar_paper_id"
|
47
|
+
) as mock_s2_retrieve_id,
|
41
48
|
):
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
+
mock_s2_search.return_value = {"result": "Mock Search Result"}
|
50
|
+
mock_s2_display.return_value = {"result": "Mock Display Result"}
|
51
|
+
mock_s2_single_rec.return_value = {"result": "Mock Single Rec"}
|
52
|
+
mock_s2_multi_rec.return_value = {"result": "Mock Multi Rec"}
|
53
|
+
mock_s2_query_results.return_value = {"result": "Mock Query Result"}
|
54
|
+
mock_s2_retrieve_id.return_value = {"paper_id": "MockPaper123"}
|
55
|
+
|
56
|
+
yield [
|
57
|
+
mock_s2_search,
|
58
|
+
mock_s2_display,
|
59
|
+
mock_s2_single_rec,
|
60
|
+
mock_s2_multi_rec,
|
61
|
+
mock_s2_query_results,
|
62
|
+
mock_s2_retrieve_id,
|
63
|
+
]
|
49
64
|
|
50
65
|
|
51
66
|
@pytest.mark.usefixtures("mock_hydra_fixture")
|
52
67
|
def test_s2_agent_initialization():
|
53
68
|
"""Test that S2 agent initializes correctly with mock configuration."""
|
54
69
|
thread_id = "test_thread"
|
55
|
-
|
56
70
|
with mock.patch(
|
57
71
|
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
58
72
|
) as mock_create:
|
59
73
|
mock_create.return_value = mock.Mock()
|
60
|
-
|
61
74
|
app = get_app(thread_id)
|
62
|
-
|
63
75
|
assert app is not None
|
64
76
|
assert mock_create.called
|
65
77
|
|
@@ -68,7 +80,6 @@ def test_s2_agent_invocation():
|
|
68
80
|
"""Test that the S2 agent processes user input and returns a valid response."""
|
69
81
|
thread_id = "test_thread"
|
70
82
|
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
71
|
-
|
72
83
|
with mock.patch(
|
73
84
|
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
74
85
|
) as mock_create:
|
@@ -78,7 +89,6 @@ def test_s2_agent_invocation():
|
|
78
89
|
"messages": [AIMessage(content="Here are some AI papers")],
|
79
90
|
"papers": {"id123": "AI Research Paper"},
|
80
91
|
}
|
81
|
-
|
82
92
|
app = get_app(thread_id)
|
83
93
|
result = app.invoke(
|
84
94
|
mock_state,
|
@@ -90,7 +100,6 @@ def test_s2_agent_invocation():
|
|
90
100
|
}
|
91
101
|
},
|
92
102
|
)
|
93
|
-
|
94
103
|
assert "messages" in result
|
95
104
|
assert "papers" in result
|
96
105
|
assert result["papers"]["id123"] == "AI Research Paper"
|
@@ -99,10 +108,7 @@ def test_s2_agent_invocation():
|
|
99
108
|
def test_s2_agent_tools_assignment(request):
|
100
109
|
"""Ensure that the correct tools are assigned to the agent."""
|
101
110
|
thread_id = "test_thread"
|
102
|
-
|
103
|
-
# Dynamically retrieve the fixture to avoid redefining it in the function signature
|
104
111
|
mock_tools = request.getfixturevalue("mock_tools_fixture")
|
105
|
-
|
106
112
|
with (
|
107
113
|
mock.patch(
|
108
114
|
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
@@ -111,28 +117,88 @@ def test_s2_agent_tools_assignment(request):
|
|
111
117
|
"aiagents4pharma.talk2scholars.agents.s2_agent.ToolNode"
|
112
118
|
) as mock_toolnode,
|
113
119
|
):
|
114
|
-
|
115
120
|
mock_agent = mock.Mock()
|
116
121
|
mock_create.return_value = mock_agent
|
117
|
-
|
118
|
-
# Mock ToolNode behavior
|
119
122
|
mock_tool_instance = mock.Mock()
|
120
|
-
mock_tool_instance.tools = mock_tools
|
123
|
+
mock_tool_instance.tools = mock_tools
|
121
124
|
mock_toolnode.return_value = mock_tool_instance
|
122
|
-
|
123
125
|
get_app(thread_id)
|
124
|
-
|
125
|
-
# Ensure the agent was created with the mocked ToolNode
|
126
126
|
assert mock_toolnode.called
|
127
|
-
assert len(mock_tool_instance.tools) ==
|
127
|
+
assert len(mock_tool_instance.tools) == 6
|
128
|
+
|
129
|
+
|
130
|
+
def test_s2_query_results_tool():
|
131
|
+
"""Test if the query_results tool is correctly utilized by the agent."""
|
132
|
+
thread_id = "test_thread"
|
133
|
+
mock_state = Talk2Scholars(
|
134
|
+
messages=[HumanMessage(content="Query results for AI papers")]
|
135
|
+
)
|
136
|
+
with mock.patch(
|
137
|
+
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
138
|
+
) as mock_create:
|
139
|
+
mock_agent = mock.Mock()
|
140
|
+
mock_create.return_value = mock_agent
|
141
|
+
mock_agent.invoke.return_value = {
|
142
|
+
"messages": [HumanMessage(content="Query results for AI papers")],
|
143
|
+
"last_displayed_papers": {},
|
144
|
+
"papers": {
|
145
|
+
"query_results": "Mock Query Result"
|
146
|
+
}, # Ensure the expected key is inside 'papers'
|
147
|
+
"multi_papers": {},
|
148
|
+
}
|
149
|
+
app = get_app(thread_id)
|
150
|
+
result = app.invoke(
|
151
|
+
mock_state,
|
152
|
+
config={
|
153
|
+
"configurable": {
|
154
|
+
"thread_id": thread_id,
|
155
|
+
"checkpoint_ns": "test_ns",
|
156
|
+
"checkpoint_id": "test_checkpoint",
|
157
|
+
}
|
158
|
+
},
|
159
|
+
)
|
160
|
+
assert "query_results" in result["papers"]
|
161
|
+
assert mock_agent.invoke.called
|
162
|
+
|
163
|
+
|
164
|
+
def test_s2_retrieve_id_tool():
|
165
|
+
"""Test if the retrieve_semantic_scholar_paper_id tool is correctly utilized by the agent."""
|
166
|
+
thread_id = "test_thread"
|
167
|
+
mock_state = Talk2Scholars(
|
168
|
+
messages=[HumanMessage(content="Retrieve paper ID for AI research")]
|
169
|
+
)
|
170
|
+
with mock.patch(
|
171
|
+
"aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
|
172
|
+
) as mock_create:
|
173
|
+
mock_agent = mock.Mock()
|
174
|
+
mock_create.return_value = mock_agent
|
175
|
+
mock_agent.invoke.return_value = {
|
176
|
+
"messages": [HumanMessage(content="Retrieve paper ID for AI research")],
|
177
|
+
"last_displayed_papers": {},
|
178
|
+
"papers": {
|
179
|
+
"paper_id": "MockPaper123"
|
180
|
+
}, # Ensure 'paper_id' is inside 'papers'
|
181
|
+
"multi_papers": {},
|
182
|
+
}
|
183
|
+
app = get_app(thread_id)
|
184
|
+
result = app.invoke(
|
185
|
+
mock_state,
|
186
|
+
config={
|
187
|
+
"configurable": {
|
188
|
+
"thread_id": thread_id,
|
189
|
+
"checkpoint_ns": "test_ns",
|
190
|
+
"checkpoint_id": "test_checkpoint",
|
191
|
+
}
|
192
|
+
},
|
193
|
+
)
|
194
|
+
assert "paper_id" in result["papers"]
|
195
|
+
assert mock_agent.invoke.called
|
128
196
|
|
129
197
|
|
130
198
|
def test_s2_agent_hydra_failure():
|
131
199
|
"""Test exception handling when Hydra fails to load config."""
|
132
200
|
thread_id = "test_thread"
|
133
|
-
|
134
201
|
with mock.patch("hydra.initialize", side_effect=Exception("Hydra error")):
|
135
202
|
with pytest.raises(Exception) as exc_info:
|
136
203
|
get_app(thread_id)
|
137
|
-
|
138
204
|
assert "Hydra error" in str(exc_info.value)
|