aiagents4pharma 1.28.0__py3-none-any.whl → 1.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2scholars/agents/__init__.py +1 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +35 -209
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +86 -0
- aiagents4pharma/talk2scholars/agents/s2_agent.py +10 -6
- aiagents4pharma/talk2scholars/agents/zotero_agent.py +12 -6
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +2 -48
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +5 -28
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +5 -21
- aiagents4pharma/talk2scholars/configs/config.yaml +3 -0
- aiagents4pharma/talk2scholars/configs/tools/__init__.py +1 -0
- aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +1 -1
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -1
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -1
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +42 -1
- aiagents4pharma/talk2scholars/configs/tools/zotero_write/__inti__.py +3 -0
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +1 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +186 -111
- aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +142 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +154 -0
- aiagents4pharma/talk2scholars/tests/test_s2_display.py +74 -0
- aiagents4pharma/talk2scholars/tests/test_s2_multi.py +282 -0
- aiagents4pharma/talk2scholars/tests/test_s2_query.py +78 -0
- aiagents4pharma/talk2scholars/tests/test_s2_retrieve.py +65 -0
- aiagents4pharma/talk2scholars/tests/test_s2_search.py +266 -0
- aiagents4pharma/talk2scholars/tests/test_s2_single.py +274 -0
- aiagents4pharma/talk2scholars/tests/test_zotero_path.py +57 -0
- aiagents4pharma/talk2scholars/tests/test_zotero_read.py +412 -0
- aiagents4pharma/talk2scholars/tests/test_zotero_write.py +626 -0
- aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +17 -0
- aiagents4pharma/talk2scholars/tools/paper_download/abstract_downloader.py +43 -0
- aiagents4pharma/talk2scholars/tools/paper_download/arxiv_downloader.py +108 -0
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +60 -0
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +50 -34
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +8 -8
- aiagents4pharma/talk2scholars/tools/s2/search.py +36 -23
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +44 -38
- aiagents4pharma/talk2scholars/tools/zotero/__init__.py +2 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +5 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +63 -0
- aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +64 -19
- aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +247 -0
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/METADATA +6 -5
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/RECORD +48 -30
- aiagents4pharma/talk2scholars/tests/test_call_s2.py +0 -100
- aiagents4pharma/talk2scholars/tests/test_call_zotero.py +0 -94
- aiagents4pharma/talk2scholars/tests/test_s2_tools.py +0 -355
- aiagents4pharma/talk2scholars/tests/test_zotero_tool.py +0 -171
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/top_level.txt +0 -0
@@ -12,4 +12,45 @@ search_params:
|
|
12
12
|
# Item Types and Limit
|
13
13
|
zotero:
|
14
14
|
max_limit: 100
|
15
|
-
filter_item_types:
|
15
|
+
filter_item_types:
|
16
|
+
[
|
17
|
+
"Artwork",
|
18
|
+
"Audio Recording",
|
19
|
+
"Bill",
|
20
|
+
"Blog Post",
|
21
|
+
"Book",
|
22
|
+
"Book Section",
|
23
|
+
"Case",
|
24
|
+
"Conference Paper",
|
25
|
+
"Dataset",
|
26
|
+
"Dictionary Entry",
|
27
|
+
"Document",
|
28
|
+
"E-mail",
|
29
|
+
"Encyclopedia Article",
|
30
|
+
"Film",
|
31
|
+
"Forum Post",
|
32
|
+
"Hearing",
|
33
|
+
"Instant Message",
|
34
|
+
"Interview",
|
35
|
+
"Journal Article",
|
36
|
+
"Letter",
|
37
|
+
"Magazine Article",
|
38
|
+
"Manuscript",
|
39
|
+
"Map",
|
40
|
+
"Newspaper Article",
|
41
|
+
"Patent",
|
42
|
+
"Podcast",
|
43
|
+
"Preprint",
|
44
|
+
"Presentation",
|
45
|
+
"Radio Broadcast",
|
46
|
+
"Report",
|
47
|
+
"Software",
|
48
|
+
"Standard",
|
49
|
+
"Statute",
|
50
|
+
"Thesis",
|
51
|
+
"TV Broadcast",
|
52
|
+
"Video Recording",
|
53
|
+
"Web Page",
|
54
|
+
]
|
55
|
+
|
56
|
+
filter_excluded_types: ["attachment", "note", "annotation"]
|
@@ -3,119 +3,194 @@ Unit tests for main agent functionality.
|
|
3
3
|
Tests the supervisor agent's routing logic and state management.
|
4
4
|
"""
|
5
5
|
|
6
|
-
# pylint: disable=redefined-outer-name
|
7
6
|
# pylint: disable=redefined-outer-name,too-few-public-methods
|
8
|
-
|
9
|
-
from
|
7
|
+
|
8
|
+
from types import SimpleNamespace
|
10
9
|
import pytest
|
11
|
-
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
10
|
+
import hydra
|
11
|
+
from langchain_core.language_models.chat_models import BaseChatModel
|
12
|
+
from langchain_openai import ChatOpenAI
|
13
|
+
from pydantic import Field
|
14
|
+
from aiagents4pharma.talk2scholars.agents.main_agent import get_app
|
15
|
+
|
16
|
+
# --- Dummy LLM Implementation ---
|
17
|
+
|
18
|
+
|
19
|
+
class DummyLLM(BaseChatModel):
|
20
|
+
"""A dummy language model implementation for testing purposes."""
|
21
|
+
|
22
|
+
model_name: str = Field(...)
|
23
|
+
|
24
|
+
def _generate(self, prompt, stop=None):
|
25
|
+
"""Generate a response given a prompt."""
|
26
|
+
DummyLLM.called_prompt = prompt
|
27
|
+
return "dummy output"
|
28
|
+
|
29
|
+
@property
|
30
|
+
def _llm_type(self):
|
31
|
+
"""Return the type of the language model."""
|
32
|
+
return "dummy"
|
33
|
+
|
34
|
+
|
35
|
+
# --- Dummy Workflow and Sub-agent Functions ---
|
36
|
+
|
37
|
+
|
38
|
+
class DummyWorkflow:
|
39
|
+
"""A dummy workflow class that records arguments for verification."""
|
40
|
+
|
41
|
+
def __init__(self, supervisor_args=None):
|
42
|
+
"""Initialize the workflow with the given supervisor arguments."""
|
43
|
+
self.supervisor_args = supervisor_args or {}
|
44
|
+
self.checkpointer = None
|
45
|
+
self.name = None
|
46
|
+
|
47
|
+
def compile(self, checkpointer, name):
|
48
|
+
"""Compile the workflow with the given checkpointer and name."""
|
49
|
+
self.checkpointer = checkpointer
|
50
|
+
self.name = name
|
51
|
+
return self
|
52
|
+
|
53
|
+
|
54
|
+
def dummy_get_app_s2(uniq_id, llm_model):
|
55
|
+
"""Return a DummyWorkflow for the S2 agent."""
|
56
|
+
dummy_get_app_s2.called_uniq_id = uniq_id
|
57
|
+
dummy_get_app_s2.called_llm_model = llm_model
|
58
|
+
return DummyWorkflow(supervisor_args={"agent": "s2", "uniq_id": uniq_id})
|
59
|
+
|
60
|
+
|
61
|
+
def dummy_get_app_zotero(uniq_id, llm_model):
|
62
|
+
"""Return a DummyWorkflow for the Zotero agent."""
|
63
|
+
dummy_get_app_zotero.called_uniq_id = uniq_id
|
64
|
+
dummy_get_app_zotero.called_llm_model = llm_model
|
65
|
+
return DummyWorkflow(supervisor_args={"agent": "zotero", "uniq_id": uniq_id})
|
66
|
+
|
67
|
+
|
68
|
+
def dummy_create_supervisor(apps, model, state_schema, **kwargs):
|
69
|
+
"""Return a DummyWorkflow for the supervisor."""
|
70
|
+
dummy_create_supervisor.called_kwargs = kwargs
|
71
|
+
return DummyWorkflow(
|
72
|
+
supervisor_args={
|
73
|
+
"apps": apps,
|
74
|
+
"model": model,
|
75
|
+
"state_schema": state_schema,
|
76
|
+
**kwargs,
|
77
|
+
}
|
78
|
+
)
|
79
|
+
|
80
|
+
|
81
|
+
# --- Dummy Hydra Configuration Setup ---
|
82
|
+
|
83
|
+
|
84
|
+
class DummyHydraContext:
|
85
|
+
"""A dummy context manager for mocking Hydra's initialize and compose functions."""
|
86
|
+
|
87
|
+
def __enter__(self):
|
88
|
+
"""Return None when entering the context."""
|
89
|
+
return None
|
90
|
+
|
91
|
+
def __exit__(self, exc_type, exc_val, traceback):
|
92
|
+
"""Exit function that does nothing."""
|
93
|
+
return None
|
94
|
+
|
95
|
+
|
96
|
+
def dict_to_namespace(d):
|
97
|
+
"""Convert a dictionary to a SimpleNamespace object."""
|
98
|
+
return SimpleNamespace(
|
99
|
+
**{
|
100
|
+
key: dict_to_namespace(val) if isinstance(val, dict) else val
|
101
|
+
for key, val in d.items()
|
102
|
+
}
|
103
|
+
)
|
104
|
+
|
105
|
+
|
106
|
+
dummy_config = {
|
107
|
+
"agents": {
|
108
|
+
"talk2scholars": {"main_agent": {"system_prompt": "Dummy system prompt"}}
|
109
|
+
}
|
110
|
+
}
|
111
|
+
|
112
|
+
|
113
|
+
class DummyHydraCompose:
|
114
|
+
"""A dummy class that returns a namespace from a dummy config dictionary."""
|
115
|
+
|
116
|
+
def __init__(self, config):
|
117
|
+
"""Constructor that stores the dummy config."""
|
118
|
+
self.config = config
|
119
|
+
|
120
|
+
def __getattr__(self, item):
|
121
|
+
"""Return a namespace from the dummy config."""
|
122
|
+
return dict_to_namespace(self.config.get(item, {}))
|
123
|
+
|
124
|
+
|
125
|
+
# --- Pytest Fixtures to Patch Dependencies ---
|
126
|
+
|
127
|
+
|
128
|
+
@pytest.fixture(autouse=True)
|
129
|
+
def patch_hydra(monkeypatch):
|
130
|
+
"""Patch the hydra.initialize and hydra.compose functions to return dummy objects."""
|
131
|
+
monkeypatch.setattr(
|
132
|
+
hydra, "initialize", lambda version_base, config_path: DummyHydraContext()
|
133
|
+
)
|
134
|
+
monkeypatch.setattr(
|
135
|
+
hydra, "compose", lambda config_name, overrides: DummyHydraCompose(dummy_config)
|
136
|
+
)
|
15
137
|
|
16
138
|
|
17
139
|
@pytest.fixture(autouse=True)
|
18
|
-
def
|
19
|
-
"""
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
def
|
64
|
-
"""
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
"""
|
75
|
-
mock_llm = Mock()
|
76
|
-
thread_id = "test_thread"
|
77
|
-
|
78
|
-
class MockRouter:
|
79
|
-
"""Mock router class."""
|
80
|
-
|
81
|
-
next = random.choice(["s2_agent", "zotero_agent"])
|
82
|
-
|
83
|
-
with (
|
84
|
-
patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
|
85
|
-
patch.object(mock_llm, "invoke", return_value=MockRouter()),
|
86
|
-
):
|
87
|
-
supervisor_node = make_supervisor_node(mock_llm, thread_id)
|
88
|
-
mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
|
89
|
-
result = supervisor_node(mock_state)
|
90
|
-
|
91
|
-
# Accept either "s2_agent" or "zotero_agent"
|
92
|
-
assert result.goto in ["s2_agent", "zotero_agent"]
|
93
|
-
|
94
|
-
|
95
|
-
def test_supervisor_node_finish():
|
96
|
-
"""Test that supervisor node correctly handles FINISH case."""
|
97
|
-
mock_llm = Mock()
|
98
|
-
thread_id = "test_thread"
|
99
|
-
|
100
|
-
class MockRouter:
|
101
|
-
"""Mock router class."""
|
102
|
-
|
103
|
-
next = "FINISH"
|
104
|
-
|
105
|
-
class MockAIResponse:
|
106
|
-
"""Mock AI response class."""
|
107
|
-
|
108
|
-
def __init__(self):
|
109
|
-
self.content = "Final AI Response"
|
110
|
-
|
111
|
-
with (
|
112
|
-
patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
|
113
|
-
patch.object(mock_llm, "invoke", side_effect=[MockRouter(), MockAIResponse()]),
|
114
|
-
):
|
115
|
-
supervisor_node = make_supervisor_node(mock_llm, thread_id)
|
116
|
-
mock_state = Talk2Scholars(messages=[HumanMessage(content="End conversation")])
|
117
|
-
result = supervisor_node(mock_state)
|
118
|
-
assert result.goto == END
|
119
|
-
assert "messages" in result.update
|
120
|
-
assert isinstance(result.update["messages"], AIMessage)
|
121
|
-
assert result.update["messages"].content == "Final AI Response"
|
140
|
+
def patch_sub_agents_and_supervisor(monkeypatch):
|
141
|
+
"""Patch the sub-agents and supervisor creation functions."""
|
142
|
+
monkeypatch.setattr(
|
143
|
+
"aiagents4pharma.talk2scholars.agents.main_agent.get_app_s2", dummy_get_app_s2
|
144
|
+
)
|
145
|
+
monkeypatch.setattr(
|
146
|
+
"aiagents4pharma.talk2scholars.agents.main_agent.get_app_zotero",
|
147
|
+
dummy_get_app_zotero,
|
148
|
+
)
|
149
|
+
monkeypatch.setattr(
|
150
|
+
"aiagents4pharma.talk2scholars.agents.main_agent.create_supervisor",
|
151
|
+
dummy_create_supervisor,
|
152
|
+
)
|
153
|
+
|
154
|
+
|
155
|
+
# --- Test Cases ---
|
156
|
+
|
157
|
+
|
158
|
+
def test_dummy_llm_generate():
|
159
|
+
"""Test the dummy LLM's generate function."""
|
160
|
+
dummy = DummyLLM(model_name="test-model")
|
161
|
+
output = getattr(dummy, "_generate")("any prompt")
|
162
|
+
assert output == "dummy output"
|
163
|
+
|
164
|
+
|
165
|
+
def test_dummy_llm_llm_type():
|
166
|
+
"""Test the dummy LLM's _llm_type property."""
|
167
|
+
dummy = DummyLLM(model_name="test-model")
|
168
|
+
assert getattr(dummy, "_llm_type") == "dummy"
|
169
|
+
|
170
|
+
|
171
|
+
def test_get_app_with_gpt4o_mini():
|
172
|
+
"""
|
173
|
+
Test that get_app replaces a 'gpt-4o-mini' LLM with a new ChatOpenAI instance.
|
174
|
+
"""
|
175
|
+
uniq_id = "test_thread"
|
176
|
+
dummy_llm = DummyLLM(model_name="gpt-4o-mini")
|
177
|
+
app = get_app(uniq_id, dummy_llm)
|
178
|
+
|
179
|
+
supervisor_args = getattr(app, "supervisor_args", {})
|
180
|
+
assert isinstance(supervisor_args.get("model"), ChatOpenAI)
|
181
|
+
assert supervisor_args.get("prompt") == "Dummy system prompt"
|
182
|
+
assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"
|
183
|
+
|
184
|
+
|
185
|
+
def test_get_app_with_other_model():
|
186
|
+
"""
|
187
|
+
Test that get_app does not replace the LLM if its model_name is not 'gpt-4o-mini'.
|
188
|
+
"""
|
189
|
+
uniq_id = "test_thread_2"
|
190
|
+
dummy_llm = DummyLLM(model_name="other-model")
|
191
|
+
app = get_app(uniq_id, dummy_llm)
|
192
|
+
|
193
|
+
supervisor_args = getattr(app, "supervisor_args", {})
|
194
|
+
assert supervisor_args.get("model") is dummy_llm
|
195
|
+
assert supervisor_args.get("prompt") == "Dummy system prompt"
|
196
|
+
assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"
|
@@ -0,0 +1,142 @@
|
|
1
|
+
"""Unit tests for the paper download agent in Talk2Scholars."""
|
2
|
+
|
3
|
+
from unittest import mock
|
4
|
+
import pytest
|
5
|
+
from langchain_core.messages import HumanMessage, AIMessage
|
6
|
+
from langchain_core.language_models.chat_models import BaseChatModel
|
7
|
+
from ..agents.paper_download_agent import get_app
|
8
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
9
|
+
|
10
|
+
|
11
|
+
@pytest.fixture(autouse=True)
|
12
|
+
def mock_hydra_fixture():
|
13
|
+
"""Mocks Hydra configuration for tests."""
|
14
|
+
with mock.patch("hydra.initialize"), mock.patch("hydra.compose") as mock_compose:
|
15
|
+
cfg_mock = mock.MagicMock()
|
16
|
+
cfg_mock.agents.talk2scholars.s2_agent.temperature = 0
|
17
|
+
cfg_mock.agents.talk2scholars.paper_download_agent.prompt = "Test prompt"
|
18
|
+
mock_compose.return_value = cfg_mock
|
19
|
+
yield mock_compose
|
20
|
+
|
21
|
+
|
22
|
+
@pytest.fixture
|
23
|
+
def mock_tools_fixture():
|
24
|
+
"""Mocks paper download tools to prevent real HTTP calls."""
|
25
|
+
with (
|
26
|
+
mock.patch(
|
27
|
+
"aiagents4pharma.talk2scholars.tools.paper_download."
|
28
|
+
"download_arxiv_input.download_arxiv_paper"
|
29
|
+
) as mock_download_arxiv_paper,
|
30
|
+
mock.patch(
|
31
|
+
"aiagents4pharma.talk2scholars.tools.s2.query_results.query_results"
|
32
|
+
) as mock_query_results,
|
33
|
+
):
|
34
|
+
mock_download_arxiv_paper.return_value = {
|
35
|
+
"pdf_data": {"dummy_key": "dummy_value"}
|
36
|
+
}
|
37
|
+
mock_query_results.return_value = {
|
38
|
+
"result": "Mocked Query Result"
|
39
|
+
}
|
40
|
+
yield [mock_download_arxiv_paper, mock_query_results]
|
41
|
+
|
42
|
+
@pytest.mark.usefixtures("mock_hydra_fixture")
|
43
|
+
def test_paper_download_agent_initialization():
|
44
|
+
"""Ensures the paper download agent initializes properly with a prompt."""
|
45
|
+
thread_id = "test_thread_paper_dl"
|
46
|
+
llm_mock = mock.Mock(spec=BaseChatModel) # Mock LLM
|
47
|
+
|
48
|
+
with mock.patch(
|
49
|
+
"aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
|
50
|
+
) as mock_create_agent:
|
51
|
+
mock_create_agent.return_value = mock.Mock()
|
52
|
+
|
53
|
+
app = get_app(thread_id, llm_mock)
|
54
|
+
assert app is not None, "The agent app should be successfully created."
|
55
|
+
assert mock_create_agent.called
|
56
|
+
|
57
|
+
def test_paper_download_agent_invocation():
|
58
|
+
"""Verifies agent processes queries and updates state correctly."""
|
59
|
+
_ = mock_tools_fixture # Prevents unused-argument warning
|
60
|
+
thread_id = "test_thread_paper_dl"
|
61
|
+
mock_state = Talk2Scholars(
|
62
|
+
messages=[HumanMessage(content="Download paper 1234.5678")]
|
63
|
+
)
|
64
|
+
llm_mock = mock.Mock(spec=BaseChatModel)
|
65
|
+
|
66
|
+
with mock.patch(
|
67
|
+
"aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
|
68
|
+
) as mock_create_agent:
|
69
|
+
mock_agent = mock.Mock()
|
70
|
+
mock_create_agent.return_value = mock_agent
|
71
|
+
mock_agent.invoke.return_value = {
|
72
|
+
"messages": [AIMessage(content="Here is the paper")],
|
73
|
+
"pdf_data": {"file_bytes": b"FAKE_PDF_CONTENTS"},
|
74
|
+
}
|
75
|
+
|
76
|
+
|
77
|
+
app = get_app(thread_id, llm_mock)
|
78
|
+
result = app.invoke(
|
79
|
+
mock_state,
|
80
|
+
config={
|
81
|
+
"configurable": {
|
82
|
+
"thread_id": thread_id,
|
83
|
+
"checkpoint_ns": "test_ns",
|
84
|
+
"checkpoint_id": "test_checkpoint",
|
85
|
+
}
|
86
|
+
},
|
87
|
+
)
|
88
|
+
|
89
|
+
assert "messages" in result
|
90
|
+
assert "pdf_data" in result
|
91
|
+
|
92
|
+
|
93
|
+
def test_paper_download_agent_tools_assignment(request): # Keep fixture name
|
94
|
+
"""Checks correct tool assignment (download_arxiv_paper, query_results)."""
|
95
|
+
thread_id = "test_thread_paper_dl"
|
96
|
+
mock_tools = request.getfixturevalue("mock_tools_fixture")
|
97
|
+
llm_mock = mock.Mock(spec=BaseChatModel)
|
98
|
+
|
99
|
+
with (
|
100
|
+
mock.patch(
|
101
|
+
"aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
|
102
|
+
) as mock_create_agent,
|
103
|
+
mock.patch(
|
104
|
+
"aiagents4pharma.talk2scholars.agents.paper_download_agent.ToolNode"
|
105
|
+
) as mock_toolnode,
|
106
|
+
):
|
107
|
+
mock_agent = mock.Mock()
|
108
|
+
mock_create_agent.return_value = mock_agent
|
109
|
+
mock_tool_instance = mock.Mock()
|
110
|
+
mock_tool_instance.tools = mock_tools
|
111
|
+
mock_toolnode.return_value= mock_tool_instance
|
112
|
+
|
113
|
+
get_app(thread_id, llm_mock)
|
114
|
+
assert mock_toolnode.called
|
115
|
+
assert len(mock_tool_instance.tools) == 2
|
116
|
+
|
117
|
+
|
118
|
+
def test_paper_download_agent_hydra_failure():
|
119
|
+
"""Confirms the agent gracefully handles exceptions if Hydra fails."""
|
120
|
+
thread_id = "test_thread_paper_dl"
|
121
|
+
llm_mock = mock.Mock(spec=BaseChatModel)
|
122
|
+
|
123
|
+
with mock.patch("hydra.initialize", side_effect=Exception("Mock Hydra failure")):
|
124
|
+
with pytest.raises(Exception) as exc_info:
|
125
|
+
get_app(thread_id, llm_mock)
|
126
|
+
assert "Mock Hydra failure" in str(exc_info.value)
|
127
|
+
|
128
|
+
|
129
|
+
def test_paper_download_agent_model_failure():
|
130
|
+
"""Ensures agent handles model-related failures gracefully."""
|
131
|
+
thread_id = "test_thread_paper_dl"
|
132
|
+
llm_mock = mock.Mock(spec=BaseChatModel)
|
133
|
+
|
134
|
+
with mock.patch(
|
135
|
+
"aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent",
|
136
|
+
side_effect=Exception("Mock model failure"),
|
137
|
+
):
|
138
|
+
with pytest.raises(Exception) as exc_info:
|
139
|
+
get_app(thread_id, llm_mock)
|
140
|
+
assert "Mock model failure" in str(exc_info.value), (
|
141
|
+
"Model initialization failure should raise an exception."
|
142
|
+
)
|
@@ -0,0 +1,154 @@
|
|
1
|
+
"""
|
2
|
+
Unit tests for arXiv paper downloading functionality, including:
|
3
|
+
- AbstractPaperDownloader (base class)
|
4
|
+
- ArxivPaperDownloader (arXiv-specific implementation)
|
5
|
+
- download_arxiv_paper tool function.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from unittest.mock import patch, MagicMock
|
9
|
+
import pytest
|
10
|
+
import requests
|
11
|
+
from requests.exceptions import HTTPError
|
12
|
+
from langgraph.types import Command
|
13
|
+
from langchain_core.messages import ToolMessage
|
14
|
+
|
15
|
+
# Import the classes and function under test
|
16
|
+
from aiagents4pharma.talk2scholars.tools.paper_download.abstract_downloader import (
|
17
|
+
AbstractPaperDownloader,
|
18
|
+
)
|
19
|
+
from aiagents4pharma.talk2scholars.tools.paper_download.arxiv_downloader import (
|
20
|
+
ArxivPaperDownloader,
|
21
|
+
)
|
22
|
+
from aiagents4pharma.talk2scholars.tools.paper_download.download_arxiv_input import (
|
23
|
+
download_arxiv_paper,
|
24
|
+
)
|
25
|
+
|
26
|
+
@pytest.mark.parametrize("class_obj", [AbstractPaperDownloader])
|
27
|
+
|
28
|
+
def test_abstract_downloader_cannot_be_instantiated(class_obj):
|
29
|
+
"""
|
30
|
+
Validates that AbstractPaperDownloader is indeed abstract and raises TypeError
|
31
|
+
if anyone attempts to instantiate it directly.
|
32
|
+
"""
|
33
|
+
with pytest.raises(TypeError):
|
34
|
+
class_obj()
|
35
|
+
|
36
|
+
|
37
|
+
@pytest.fixture(name="arxiv_downloader_fixture")
|
38
|
+
@pytest.mark.usefixtures("mock_hydra_config_setup")
|
39
|
+
def fixture_arxiv_downloader():
|
40
|
+
"""
|
41
|
+
Provides an ArxivPaperDownloader instance with a mocked Hydra config.
|
42
|
+
"""
|
43
|
+
return ArxivPaperDownloader()
|
44
|
+
|
45
|
+
|
46
|
+
def test_fetch_metadata_success(arxiv_downloader_fixture,):
|
47
|
+
"""
|
48
|
+
Ensures fetch_metadata retrieves XML data correctly, given a successful HTTP response.
|
49
|
+
"""
|
50
|
+
mock_response = MagicMock()
|
51
|
+
mock_response.text = "<xml>Mock ArXiv Metadata</xml>"
|
52
|
+
mock_response.raise_for_status = MagicMock()
|
53
|
+
|
54
|
+
with patch.object(requests, "get", return_value=mock_response) as mock_get:
|
55
|
+
paper_id = "1234.5678"
|
56
|
+
result = arxiv_downloader_fixture.fetch_metadata(paper_id)
|
57
|
+
mock_get.assert_called_once_with(
|
58
|
+
"http://export.arxiv.org/api/query?search_query=id:1234.5678&start=0&max_results=1",
|
59
|
+
timeout=10,
|
60
|
+
)
|
61
|
+
assert result["xml"] == "<xml>Mock ArXiv Metadata</xml>"
|
62
|
+
|
63
|
+
|
64
|
+
def test_fetch_metadata_http_error(arxiv_downloader_fixture):
|
65
|
+
"""
|
66
|
+
Validates that fetch_metadata raises HTTPError when the response indicates a failure.
|
67
|
+
"""
|
68
|
+
mock_response = MagicMock()
|
69
|
+
mock_response.raise_for_status.side_effect = HTTPError("Mocked HTTP failure")
|
70
|
+
|
71
|
+
with patch.object(requests, "get", return_value=mock_response):
|
72
|
+
with pytest.raises(HTTPError):
|
73
|
+
arxiv_downloader_fixture.fetch_metadata("invalid_id")
|
74
|
+
|
75
|
+
|
76
|
+
def test_download_pdf_success(arxiv_downloader_fixture):
|
77
|
+
"""
|
78
|
+
Tests that download_pdf fetches the PDF link from metadata and successfully
|
79
|
+
retrieves the binary content.
|
80
|
+
"""
|
81
|
+
mock_metadata = {
|
82
|
+
"xml": """
|
83
|
+
<feed xmlns="http://www.w3.org/2005/Atom">
|
84
|
+
<entry>
|
85
|
+
<link title="pdf" href="http://test.arxiv.org/pdf/1234.5678v1.pdf"/>
|
86
|
+
</entry>
|
87
|
+
</feed>
|
88
|
+
"""
|
89
|
+
}
|
90
|
+
|
91
|
+
mock_pdf_response = MagicMock()
|
92
|
+
mock_pdf_response.raise_for_status = MagicMock()
|
93
|
+
mock_pdf_response.iter_content = lambda chunk_size: [b"FAKE_PDF_CONTENT"]
|
94
|
+
|
95
|
+
with patch.object(arxiv_downloader_fixture, "fetch_metadata", return_value=mock_metadata):
|
96
|
+
with patch.object(requests, "get", return_value=mock_pdf_response) as mock_get:
|
97
|
+
result = arxiv_downloader_fixture.download_pdf("1234.5678")
|
98
|
+
assert result["pdf_object"] == b"FAKE_PDF_CONTENT"
|
99
|
+
assert result["pdf_url"] == "http://test.arxiv.org/pdf/1234.5678v1.pdf"
|
100
|
+
assert result["arxiv_id"] == "1234.5678"
|
101
|
+
mock_get.assert_called_once_with(
|
102
|
+
"http://test.arxiv.org/pdf/1234.5678v1.pdf",
|
103
|
+
stream=True,
|
104
|
+
timeout=10,
|
105
|
+
)
|
106
|
+
|
107
|
+
|
108
|
+
def test_download_pdf_no_pdf_link(arxiv_downloader_fixture):
|
109
|
+
"""
|
110
|
+
Ensures a RuntimeError is raised if no <link> with title="pdf" is found in the XML.
|
111
|
+
"""
|
112
|
+
mock_metadata = {"xml": "<feed></feed>"}
|
113
|
+
|
114
|
+
with patch.object(arxiv_downloader_fixture, "fetch_metadata", return_value=mock_metadata):
|
115
|
+
with pytest.raises(RuntimeError, match="Failed to download PDF"):
|
116
|
+
arxiv_downloader_fixture.download_pdf("1234.5678")
|
117
|
+
|
118
|
+
|
119
|
+
def test_download_arxiv_paper_tool_success(arxiv_downloader_fixture):
|
120
|
+
"""
|
121
|
+
Validates download_arxiv_paper orchestrates the ArxivPaperDownloader correctly,
|
122
|
+
returning a Command with PDF data and success messages.
|
123
|
+
"""
|
124
|
+
mock_metadata = {"xml": "<mockxml></mockxml>"}
|
125
|
+
mock_pdf_response = {
|
126
|
+
"pdf_object": b"FAKE_PDF_CONTENT",
|
127
|
+
"pdf_url": "http://test.arxiv.org/mock.pdf",
|
128
|
+
"arxiv_id": "9999.8888",
|
129
|
+
}
|
130
|
+
|
131
|
+
with patch(
|
132
|
+
"aiagents4pharma.talk2scholars.tools.paper_download.download_arxiv_input."
|
133
|
+
"ArxivPaperDownloader",
|
134
|
+
return_value=arxiv_downloader_fixture,
|
135
|
+
):
|
136
|
+
with patch.object(arxiv_downloader_fixture, "fetch_metadata", return_value=mock_metadata):
|
137
|
+
with patch.object(
|
138
|
+
arxiv_downloader_fixture,
|
139
|
+
"download_pdf",
|
140
|
+
return_value=mock_pdf_response,
|
141
|
+
):
|
142
|
+
command_result = download_arxiv_paper.invoke(
|
143
|
+
{"arxiv_id": "9999.8888", "tool_call_id": "test_tool_call"}
|
144
|
+
)
|
145
|
+
|
146
|
+
assert isinstance(command_result, Command)
|
147
|
+
assert "pdf_data" in command_result.update
|
148
|
+
assert command_result.update["pdf_data"] == mock_pdf_response
|
149
|
+
|
150
|
+
messages = command_result.update.get("messages", [])
|
151
|
+
assert len(messages) == 1
|
152
|
+
assert isinstance(messages[0], ToolMessage)
|
153
|
+
assert "Successfully downloaded PDF" in messages[0].content
|
154
|
+
assert "9999.8888" in messages[0].content
|