aiagents4pharma 1.28.0__py3-none-any.whl → 1.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2scholars/agents/main_agent.py +35 -209
- aiagents4pharma/talk2scholars/agents/s2_agent.py +10 -6
- aiagents4pharma/talk2scholars/agents/zotero_agent.py +12 -6
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +2 -48
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +5 -28
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +5 -21
- aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/__init__.py +1 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +1 -1
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -1
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -1
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +42 -1
- aiagents4pharma/talk2scholars/configs/tools/zotero_write/__inti__.py +3 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +186 -111
- aiagents4pharma/talk2scholars/tests/test_s2_display.py +74 -0
- aiagents4pharma/talk2scholars/tests/test_s2_multi.py +282 -0
- aiagents4pharma/talk2scholars/tests/test_s2_query.py +78 -0
- aiagents4pharma/talk2scholars/tests/test_s2_retrieve.py +65 -0
- aiagents4pharma/talk2scholars/tests/test_s2_search.py +266 -0
- aiagents4pharma/talk2scholars/tests/test_s2_single.py +274 -0
- aiagents4pharma/talk2scholars/tests/test_zotero_path.py +57 -0
- aiagents4pharma/talk2scholars/tests/test_zotero_read.py +412 -0
- aiagents4pharma/talk2scholars/tests/test_zotero_write.py +626 -0
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +50 -34
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +8 -8
- aiagents4pharma/talk2scholars/tools/s2/search.py +36 -23
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +44 -38
- aiagents4pharma/talk2scholars/tools/zotero/__init__.py +2 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +5 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +63 -0
- aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +64 -19
- aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +247 -0
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/METADATA +6 -5
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/RECORD +37 -28
- aiagents4pharma/talk2scholars/tests/test_call_s2.py +0 -100
- aiagents4pharma/talk2scholars/tests/test_call_zotero.py +0 -94
- aiagents4pharma/talk2scholars/tests/test_s2_tools.py +0 -355
- aiagents4pharma/talk2scholars/tests/test_zotero_tool.py +0 -171
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/top_level.txt +0 -0
@@ -3,119 +3,194 @@ Unit tests for main agent functionality.
|
|
3
3
|
Tests the supervisor agent's routing logic and state management.
|
4
4
|
"""
|
5
5
|
|
6
|
-
# pylint: disable=redefined-outer-name
|
7
6
|
# pylint: disable=redefined-outer-name,too-few-public-methods
|
8
|
-
|
9
|
-
from
|
7
|
+
|
8
|
+
from types import SimpleNamespace
|
10
9
|
import pytest
|
11
|
-
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
10
|
+
import hydra
|
11
|
+
from langchain_core.language_models.chat_models import BaseChatModel
|
12
|
+
from langchain_openai import ChatOpenAI
|
13
|
+
from pydantic import Field
|
14
|
+
from aiagents4pharma.talk2scholars.agents.main_agent import get_app
|
15
|
+
|
16
|
+
# --- Dummy LLM Implementation ---
|
17
|
+
|
18
|
+
|
19
|
+
class DummyLLM(BaseChatModel):
    """A dummy language model implementation for testing purposes.

    It never talks to a provider: ``_generate`` records the prompt it received
    on the class attribute ``DummyLLM.called_prompt`` and returns a fixed
    string. ``model_name`` is required so tests can steer ``get_app``.
    """

    model_name: str = Field(...)

    def _generate(self, prompt, stop=None, run_manager=None, **kwargs):
        """Record ``prompt`` and return a fixed placeholder output.

        ``run_manager`` and ``**kwargs`` are accepted (and ignored) so the
        signature is call-compatible with ``BaseChatModel._generate``, which
        LangChain calls with those extra arguments.

        Returns:
            str: always ``"dummy output"``. NOTE: this is not a ``ChatResult``;
            the dummy is meant to be called directly by the tests only.
        """
        DummyLLM.called_prompt = prompt
        return "dummy output"

    @property
    def _llm_type(self):
        """Return the type of the language model."""
        return "dummy"
|
33
|
+
|
34
|
+
|
35
|
+
# --- Dummy Workflow and Sub-agent Functions ---
|
36
|
+
|
37
|
+
|
38
|
+
class DummyWorkflow:
    """Stand-in for a LangGraph workflow that simply remembers what it was given.

    Tests inspect ``supervisor_args``, ``checkpointer`` and ``name`` to check
    how the code under test wired things together.
    """

    def __init__(self, supervisor_args=None):
        """Keep the supervisor arguments, defaulting to a fresh empty dict."""
        if not supervisor_args:
            supervisor_args = {}
        self.supervisor_args = supervisor_args
        self.checkpointer, self.name = None, None

    def compile(self, checkpointer, name):
        """Remember ``checkpointer`` and ``name`` and hand back this same object."""
        self.checkpointer, self.name = checkpointer, name
        return self
|
52
|
+
|
53
|
+
|
54
|
+
def dummy_get_app_s2(uniq_id, llm_model):
    """Fake ``get_app`` for the S2 sub-agent.

    The call arguments are remembered on the function object itself
    (``called_uniq_id`` / ``called_llm_model``); the returned DummyWorkflow is
    tagged with ``agent="s2"`` and the given ``uniq_id``.
    """
    recorder = dummy_get_app_s2
    recorder.called_uniq_id, recorder.called_llm_model = uniq_id, llm_model
    tagged_args = {"agent": "s2", "uniq_id": uniq_id}
    return DummyWorkflow(supervisor_args=tagged_args)
|
59
|
+
|
60
|
+
|
61
|
+
def dummy_get_app_zotero(uniq_id, llm_model):
    """Fake ``get_app`` for the Zotero sub-agent.

    The call arguments are remembered on the function object itself
    (``called_uniq_id`` / ``called_llm_model``); the returned DummyWorkflow is
    tagged with ``agent="zotero"`` and the given ``uniq_id``.
    """
    recorder = dummy_get_app_zotero
    recorder.called_uniq_id, recorder.called_llm_model = uniq_id, llm_model
    tagged_args = {"agent": "zotero", "uniq_id": uniq_id}
    return DummyWorkflow(supervisor_args=tagged_args)
|
66
|
+
|
67
|
+
|
68
|
+
def dummy_create_supervisor(apps, model, state_schema, **kwargs):
    """Fake ``create_supervisor`` that captures everything it is called with.

    Extra keyword arguments are stored on ``dummy_create_supervisor.called_kwargs``
    and, merged after ``apps``/``model``/``state_schema``, become the returned
    workflow's ``supervisor_args``.
    """
    dummy_create_supervisor.called_kwargs = kwargs
    captured = dict(apps=apps, model=model, state_schema=state_schema)
    captured.update(kwargs)
    return DummyWorkflow(supervisor_args=captured)
|
79
|
+
|
80
|
+
|
81
|
+
# --- Dummy Hydra Configuration Setup ---
|
82
|
+
|
83
|
+
|
84
|
+
class DummyHydraContext:
    """No-op context manager used in place of ``hydra.initialize``."""

    def __enter__(self):
        """Enter the context; the ``as`` target (if any) is bound to ``None``."""

    def __exit__(self, exc_type, exc_val, traceback):
        """Leave the context; returning ``None`` lets exceptions propagate."""
|
94
|
+
|
95
|
+
|
96
|
+
def dict_to_namespace(d):
    """Recursively turn a (possibly nested) dict into ``SimpleNamespace`` objects.

    Only values that are themselves dicts are converted; anything else
    (including lists) is stored unchanged.
    """
    attrs = {}
    for name, value in d.items():
        attrs[name] = dict_to_namespace(value) if isinstance(value, dict) else value
    return SimpleNamespace(**attrs)
|
104
|
+
|
105
|
+
|
106
|
+
# Minimal config tree served by the fake hydra.compose: only the main agent's
# system prompt is present.
dummy_config = dict(
    agents=dict(
        talk2scholars=dict(
            main_agent=dict(system_prompt="Dummy system prompt"),
        ),
    ),
)
|
111
|
+
|
112
|
+
|
113
|
+
class DummyHydraCompose:
    """Stand-in for the object returned by ``hydra.compose``.

    ``cfg.<key>`` returns ``dict_to_namespace`` of the matching top-level entry
    of the dummy config (an empty namespace when the key is absent), so chains
    such as ``cfg.agents.talk2scholars.main_agent`` work.
    """

    def __init__(self, config):
        """Store the dummy config dictionary."""
        self.config = config

    def __getattr__(self, item):
        """Return a namespace view of ``self.config[item]`` (or an empty one).

        Dunder names and ``config`` itself are refused with AttributeError.
        Answering protocol lookups such as ``__deepcopy__``/``__setstate__``
        with a namespace breaks ``copy``/``pickle``. Reading ``self.config``
        before ``__init__`` has run (e.g. on an instance created by ``copy``)
        would otherwise re-enter this method forever.
        """
        if item == "config" or (item.startswith("__") and item.endswith("__")):
            raise AttributeError(item)
        return dict_to_namespace(self.config.get(item, {}))
|
123
|
+
|
124
|
+
|
125
|
+
# --- Pytest Fixtures to Patch Dependencies ---
|
126
|
+
|
127
|
+
|
128
|
+
@pytest.fixture(autouse=True)
def patch_hydra(monkeypatch):
    """Autouse fixture replacing hydra.initialize/compose with fakes.

    ``initialize`` yields a no-op DummyHydraContext; ``compose`` returns a
    DummyHydraCompose wrapping the module-level ``dummy_config``.
    """
    fakes = {
        "initialize": lambda version_base, config_path: DummyHydraContext(),
        "compose": lambda config_name, overrides: DummyHydraCompose(dummy_config),
    }
    for attr, fake in fakes.items():
        monkeypatch.setattr(hydra, attr, fake)
|
15
137
|
|
16
138
|
|
17
139
|
@pytest.fixture(autouse=True)
|
18
|
-
def
|
19
|
-
"""
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
def
|
64
|
-
"""
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
"""
|
75
|
-
mock_llm = Mock()
|
76
|
-
thread_id = "test_thread"
|
77
|
-
|
78
|
-
class MockRouter:
|
79
|
-
"""Mock router class."""
|
80
|
-
|
81
|
-
next = random.choice(["s2_agent", "zotero_agent"])
|
82
|
-
|
83
|
-
with (
|
84
|
-
patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
|
85
|
-
patch.object(mock_llm, "invoke", return_value=MockRouter()),
|
86
|
-
):
|
87
|
-
supervisor_node = make_supervisor_node(mock_llm, thread_id)
|
88
|
-
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
89
|
-
result = supervisor_node(mock_state)
|
90
|
-
|
91
|
-
# Accept either "s2_agent" or "zotero_agent"
|
92
|
-
assert result.goto in ["s2_agent", "zotero_agent"]
|
93
|
-
|
94
|
-
|
95
|
-
def test_supervisor_node_finish():
    """When routing says FINISH, the supervisor ends with the final AIMessage."""
    mock_llm = Mock()

    class MockRouter:
        """Mock router class."""

        next = "FINISH"

    class MockAIResponse:
        """Mock AI response class."""

        def __init__(self):
            self.content = "Final AI Response"

    # Replies for successive invoke() calls: a FINISH routing decision, then the answer.
    scripted_replies = [MockRouter(), MockAIResponse()]
    with patch.object(
        mock_llm, "with_structured_output", return_value=mock_llm
    ), patch.object(mock_llm, "invoke", side_effect=scripted_replies):
        node = make_supervisor_node(mock_llm, "test_thread")
        state = Talk2Scholars(messages=[HumanMessage(content="End conversation")])
        result = node(state)

    assert result.goto == END
    assert "messages" in result.update
    final_message = result.update["messages"]
    assert isinstance(final_message, AIMessage)
    assert final_message.content == "Final AI Response"
|
140
|
+
def patch_sub_agents_and_supervisor(monkeypatch):
    """Swap main_agent's sub-agent factories and supervisor builder for fakes.

    Targets are given as dotted paths so the names are replaced where
    ``aiagents4pharma.talk2scholars.agents.main_agent`` looks them up.
    """
    replacements = (
        ("aiagents4pharma.talk2scholars.agents.main_agent.get_app_s2", dummy_get_app_s2),
        (
            "aiagents4pharma.talk2scholars.agents.main_agent.get_app_zotero",
            dummy_get_app_zotero,
        ),
        (
            "aiagents4pharma.talk2scholars.agents.main_agent.create_supervisor",
            dummy_create_supervisor,
        ),
    )
    for target, fake in replacements:
        monkeypatch.setattr(target, fake)
|
153
|
+
|
154
|
+
|
155
|
+
# --- Test Cases ---
|
156
|
+
|
157
|
+
|
158
|
+
def test_dummy_llm_generate():
    """DummyLLM._generate should return its fixed placeholder text."""
    generate = getattr(DummyLLM(model_name="test-model"), "_generate")
    assert generate("any prompt") == "dummy output"
|
163
|
+
|
164
|
+
|
165
|
+
def test_dummy_llm_llm_type():
    """DummyLLM._llm_type should report the string "dummy"."""
    llm_type = getattr(DummyLLM(model_name="test-model"), "_llm_type")
    assert llm_type == "dummy"
|
169
|
+
|
170
|
+
|
171
|
+
def test_get_app_with_gpt4o_mini(monkeypatch):
    """
    Test that get_app replaces a 'gpt-4o-mini' LLM with a new ChatOpenAI instance.

    ChatOpenAI cannot be constructed without an OpenAI API key. A placeholder
    key is set for the duration of the test so the result does not depend on
    the developer/CI environment. The model is never invoked here.
    """
    monkeypatch.setenv("OPENAI_API_KEY", "test-key")
    uniq_id = "test_thread"
    dummy_llm = DummyLLM(model_name="gpt-4o-mini")
    app = get_app(uniq_id, dummy_llm)

    supervisor_args = getattr(app, "supervisor_args", {})
    assert isinstance(supervisor_args.get("model"), ChatOpenAI)
    assert supervisor_args.get("prompt") == "Dummy system prompt"
    assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"
|
183
|
+
|
184
|
+
|
185
|
+
def test_get_app_with_other_model():
    """
    get_app must hand an LLM whose model_name is not 'gpt-4o-mini' straight to
    the supervisor, unchanged.
    """
    llm = DummyLLM(model_name="other-model")
    app = get_app("test_thread_2", llm)

    args = getattr(app, "supervisor_args", {})
    assert args.get("model") is llm
    assert args.get("prompt") == "Dummy system prompt"
    assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"
|
@@ -0,0 +1,74 @@
|
|
1
|
+
"""
|
2
|
+
Unit tests for S2 tools functionality.
|
3
|
+
"""
|
4
|
+
|
5
|
+
# pylint: disable=redefined-outer-name
|
6
|
+
import pytest
|
7
|
+
from langgraph.types import Command
|
8
|
+
from ..tools.s2.display_results import (
|
9
|
+
display_results,
|
10
|
+
NoPapersFoundError as raised_error,
|
11
|
+
)
|
12
|
+
|
13
|
+
|
14
|
+
@pytest.fixture
def initial_state():
    """Fresh agent state with empty single-paper and multi-paper result maps."""
    return dict(papers={}, multi_papers={})
|
18
|
+
|
19
|
+
|
20
|
+
# Fixed test data for deterministic results
|
21
|
+
# Canned Semantic Scholar search payload holding a single paper.
MOCK_SEARCH_RESPONSE = {
    "data": [
        dict(
            paperId="123",
            title="Machine Learning Basics",
            abstract="An introduction to ML",
            year=2023,
            citationCount=100,
            url="https://example.com/paper1",
            authors=[dict(name="Test Author")],
        )
    ]
}
|
34
|
+
|
35
|
+
# Paper map keyed by paperId, in the shape the tests place under state["papers"].
MOCK_STATE_PAPER = {
    "123": {
        "Title": "Machine Learning Basics",
        "Abstract": "An introduction to ML",
        "Year": 2023,
        "Citation Count": 100,
        "URL": "https://example.com/paper1",
    },
}
|
44
|
+
|
45
|
+
|
46
|
+
class TestS2Tools:
    """Unit tests for the S2 display_results tool."""

    def test_display_results_empty_state(self, initial_state):
        """display_results raises NoPapersFoundError when no search/rec has run."""
        expected = "No papers found. A search/rec needs to be performed first."
        with pytest.raises(raised_error, match=expected):
            display_results.invoke({"state": initial_state, "tool_call_id": "test123"})

    def test_display_results_shows_papers(self, initial_state):
        """display_results reports the papers stored in state as a single message."""
        state = {
            **initial_state,
            "last_displayed_papers": "papers",
            "papers": MOCK_STATE_PAPER,
        }
        result = display_results.invoke(
            input={"state": state, "tool_call_id": "test123"}
        )

        assert isinstance(result, Command)
        update = result.update
        assert isinstance(update, dict)
        assert "messages" in update
        messages = update["messages"]
        assert len(messages) == 1
        assert (
            "1 papers found. Papers are attached as an artifact."
            in messages[0].content
        )
|
@@ -0,0 +1,282 @@
|
|
1
|
+
"""
|
2
|
+
Unit tests for S2 tools functionality.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import json
|
6
|
+
from types import SimpleNamespace
|
7
|
+
import pytest
|
8
|
+
import requests
|
9
|
+
from langgraph.types import Command
|
10
|
+
from langchain_core.messages import ToolMessage
|
11
|
+
import hydra
|
12
|
+
from aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec import (
|
13
|
+
get_multi_paper_recommendations,
|
14
|
+
)
|
15
|
+
|
16
|
+
# --- Dummy Hydra Config Setup ---
|
17
|
+
|
18
|
+
|
19
|
+
class DummyHydraContext:
    """No-op context manager that replaces ``hydra.initialize`` in these tests."""

    def __enter__(self):
        """Enter the context; the ``as`` target (if any) is bound to ``None``."""

    def __exit__(self, exc_type, exc_val, traceback):
        """Leave the context; returning ``None`` lets exceptions propagate."""
|
29
|
+
|
30
|
+
|
31
|
+
# Stand-in for the composed Hydra config. Only the
# ``tools.multi_paper_recommendation`` section is populated.
_multi_rec_settings = {
    "api_endpoint": "http://dummy.endpoint/multi",
    "headers": {"Content-Type": "application/json"},
    "api_fields": ["paperId", "title", "authors"],
    "request_timeout": 10,
}
dummy_config = SimpleNamespace(
    tools=SimpleNamespace(
        multi_paper_recommendation=SimpleNamespace(**_multi_rec_settings)
    )
)
del _multi_rec_settings
|
42
|
+
|
43
|
+
# --- Dummy Response Classes and Functions for requests.post ---
|
44
|
+
|
45
|
+
|
46
|
+
class DummyResponse:
    """Minimal imitation of ``requests.Response`` for the fake ``requests.post``."""

    def __init__(self, json_data, status_code=200):
        """Remember the HTTP status and the payload later returned by ``json()``."""
        self.status_code = status_code
        self._json_data = json_data

    def json(self):
        """Return the stored payload unchanged."""
        return self._json_data

    def raise_for_status(self):
        """Raise ``requests.HTTPError`` for status codes of 400 and above."""
        if self.status_code < 400:
            return
        raise requests.HTTPError("HTTP Error")
|
62
|
+
|
63
|
+
|
64
|
+
def test_dummy_response_no_error():
    """A 200 DummyResponse must not raise from raise_for_status()."""
    ok_response = DummyResponse({"data": "success"}, status_code=200)
    outcome = ok_response.raise_for_status()
    assert outcome is None
|
70
|
+
|
71
|
+
|
72
|
+
def test_dummy_response_raise_error():
    """A 400 DummyResponse must raise requests.HTTPError from raise_for_status()."""
    bad_response = DummyResponse({"error": "fail"}, status_code=400)
    with pytest.raises(requests.HTTPError):
        bad_response.raise_for_status()
|
79
|
+
|
80
|
+
|
81
|
+
def dummy_requests_post_success(url, headers, params, data, timeout):
    """Fake ``requests.post`` answering with three recommendations.

    Every argument is stored as a ``called_*`` attribute on this function so
    tests can inspect the outgoing request. One recommendation (paperC) has
    ``authors=None`` and is expected to be filtered out by the tool.
    """
    for attr, value in (
        ("called_url", url),
        ("called_headers", headers),
        ("called_params", params),
        ("called_data", data),
        ("called_timeout", timeout),
    ):
        setattr(dummy_requests_post_success, attr, value)

    recommended = [
        {
            "paperId": "paperA",
            "title": "Multi Rec Paper A",
            "authors": ["Author X"],
            "year": 2019,
            "citationCount": 12,
            "url": "http://paperA",
            "externalIds": {"ArXiv": "arxivA"},
        },
        {
            "paperId": "paperB",
            "title": "Multi Rec Paper B",
            "authors": ["Author Y"],
            "year": 2020,
            "citationCount": 18,
            "url": "http://paperB",
            "externalIds": {},
        },
        {
            "paperId": "paperC",
            "title": "Multi Rec Paper C",
            "authors": None,  # missing authors -> should be dropped by the tool
            "year": 2021,
            "citationCount": 25,
            "url": "http://paperC",
            "externalIds": {"ArXiv": "arxivC"},
        },
    ]
    return DummyResponse({"recommendedPapers": recommended})
|
124
|
+
|
125
|
+
|
126
|
+
def dummy_requests_post_unexpected(url, headers, params, data, timeout):
    """Fake ``requests.post`` whose JSON lacks the ``recommendedPapers`` key.

    Call arguments are stored as ``called_*`` attributes on this function.
    """
    recorder = dummy_requests_post_unexpected
    recorder.called_url, recorder.called_headers = url, headers
    recorder.called_params, recorder.called_data = params, data
    recorder.called_timeout = timeout
    return DummyResponse({"error": "Invalid format"})
|
135
|
+
|
136
|
+
|
137
|
+
def dummy_requests_post_no_recs(url, headers, params, data, timeout):
    """Fake ``requests.post`` returning an empty ``recommendedPapers`` list.

    Call arguments are stored as ``called_*`` attributes on this function.
    """
    recorder = dummy_requests_post_no_recs
    recorder.called_url, recorder.called_headers = url, headers
    recorder.called_params, recorder.called_data = params, data
    recorder.called_timeout = timeout
    return DummyResponse({"recommendedPapers": []})
|
146
|
+
|
147
|
+
|
148
|
+
def dummy_requests_post_exception(url, headers, params, data, timeout):
    """Fake ``requests.post`` that simulates a network failure.

    Call arguments are stored as ``called_*`` attributes on this function
    before ``requests.exceptions.RequestException`` is raised.
    """
    recorder = dummy_requests_post_exception
    recorder.called_url, recorder.called_headers = url, headers
    recorder.called_params, recorder.called_data = params, data
    recorder.called_timeout = timeout
    raise requests.exceptions.RequestException("Connection error")
|
157
|
+
|
158
|
+
|
159
|
+
# --- Pytest Fixture to Patch Hydra ---
|
160
|
+
@pytest.fixture(autouse=True)
def patch_hydra(monkeypatch):
    """Autouse fixture replacing hydra.initialize/compose with fakes.

    ``initialize`` yields a no-op DummyHydraContext and ``compose`` returns the
    module-level ``dummy_config`` namespace.
    """
    fakes = {
        "initialize": lambda version_base, config_path: DummyHydraContext(),
        "compose": lambda config_name, overrides: dummy_config,
    }
    for attr, fake in fakes.items():
        monkeypatch.setattr(hydra, attr, fake)
|
169
|
+
|
170
|
+
|
171
|
+
# --- Test Cases ---
|
172
|
+
|
173
|
+
|
174
|
+
def test_multi_paper_rec_success(monkeypatch):
    """
    On a successful API reply, get_multi_paper_recommendations returns a Command
    whose ``multi_papers`` keeps only well-formed papers (author-less paperC is
    dropped) plus one success ToolMessage. The outgoing request must carry the
    expected query parameters and JSON payload.
    """
    monkeypatch.setattr(requests, "post", dummy_requests_post_success)

    result = get_multi_paper_recommendations.run(
        {
            "paper_ids": ["p1", "p2"],
            "tool_call_id": "test_tool_call_id",
            "limit": 2,
            "year": "2020",
        }
    )

    assert isinstance(result, Command)
    update = result.update
    assert "multi_papers" in update

    papers = update["multi_papers"]
    for kept in ("paperA", "paperB"):
        assert kept in papers
    assert "paperC" not in papers

    messages = update.get("messages", [])
    assert len(messages) == 1
    (msg,) = messages
    assert isinstance(msg, ToolMessage)
    assert "Recommendations based on multiple papers were successful" in msg.content

    sent_params = dummy_requests_post_success.called_params
    assert sent_params["limit"] == 2  # min(limit, 500)
    assert sent_params["fields"] == "paperId,title,authors"
    assert sent_params["year"] == "2020"

    payload = json.loads(dummy_requests_post_success.called_data)
    assert payload["positivePaperIds"] == ["p1", "p2"]
    assert payload["negativePaperIds"] == []
|
222
|
+
|
223
|
+
|
224
|
+
def test_multi_paper_rec_unexpected_format(monkeypatch):
    """
    A reply without ``recommendedPapers`` must surface as a RuntimeError that
    carries the "unexpected format" message.
    """
    monkeypatch.setattr(requests, "post", dummy_requests_post_unexpected)
    expected_message = (
        "Unexpected response from Semantic Scholar API. The results could not be "
        "retrieved due to an unexpected format. "
        "Please modify your search query and try again."
    )
    with pytest.raises(RuntimeError, match=expected_message):
        get_multi_paper_recommendations.run(
            {"paper_ids": ["p1", "p2"], "tool_call_id": "test_tool_call_id"}
        )
|
244
|
+
|
245
|
+
|
246
|
+
def test_multi_paper_rec_no_recommendations(monkeypatch):
    """
    An empty ``recommendedPapers`` list must surface as a RuntimeError asking
    the user to refine the query.
    """
    monkeypatch.setattr(requests, "post", dummy_requests_post_no_recs)
    expected_message = (
        "No recommendations were found for your query. Consider refining your search "
        "by using more specific keywords or different terms."
    )
    with pytest.raises(RuntimeError, match=expected_message):
        get_multi_paper_recommendations.run(
            {"paper_ids": ["p1", "p2"], "tool_call_id": "test_tool_call_id"}
        )
|
265
|
+
|
266
|
+
|
267
|
+
def test_multi_paper_rec_requests_exception(monkeypatch):
    """
    A network failure inside requests.post must surface as a RuntimeError
    asking the user to retry.
    """
    monkeypatch.setattr(requests, "post", dummy_requests_post_exception)
    expected_message = (
        "Failed to connect to Semantic Scholar API. Please retry the same query."
    )
    with pytest.raises(RuntimeError, match=expected_message):
        get_multi_paper_recommendations.run(
            {"paper_ids": ["p1", "p2"], "tool_call_id": "test_tool_call_id"}
        )
|