aiagents4pharma 1.29.0__py3-none-any.whl → 1.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2scholars/agents/__init__.py +1 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +18 -10
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +85 -0
- aiagents4pharma/talk2scholars/agents/pdf_agent.py +4 -10
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +18 -9
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +2 -2
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/config.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +6 -1
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +7 -1
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +6 -1
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +1 -1
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +4 -0
- aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +84 -53
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +24 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +142 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +154 -0
- aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +79 -15
- aiagents4pharma/talk2scholars/tests/test_routing_logic.py +12 -8
- aiagents4pharma/talk2scholars/tests/test_s2_multi.py +27 -4
- aiagents4pharma/talk2scholars/tests/test_s2_search.py +19 -3
- aiagents4pharma/talk2scholars/tests/test_s2_single.py +27 -3
- aiagents4pharma/talk2scholars/tests/test_zotero_read.py +17 -10
- aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +17 -0
- aiagents4pharma/talk2scholars/tools/paper_download/abstract_downloader.py +45 -0
- aiagents4pharma/talk2scholars/tools/paper_download/arxiv_downloader.py +115 -0
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +64 -0
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +73 -26
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +46 -22
- aiagents4pharma/talk2scholars/tools/s2/query_results.py +1 -1
- aiagents4pharma/talk2scholars/tools/s2/search.py +40 -12
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +42 -16
- aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +33 -16
- aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +39 -7
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/METADATA +2 -2
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/RECORD +41 -32
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/WHEEL +1 -1
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
|
|
1
|
+
"""Unit tests for the paper download agent in Talk2Scholars."""
|
2
|
+
|
3
|
+
from unittest import mock
|
4
|
+
import pytest
|
5
|
+
from langchain_core.messages import HumanMessage, AIMessage
|
6
|
+
from langchain_core.language_models.chat_models import BaseChatModel
|
7
|
+
from ..agents.paper_download_agent import get_app
|
8
|
+
from ..state.state_talk2scholars import Talk2Scholars
|
9
|
+
|
10
|
+
|
11
|
+
@pytest.fixture(autouse=True)
def mock_hydra_fixture():
    """Patch Hydra so every test in this module sees a canned configuration.

    ``hydra.initialize`` is replaced by a plain mock and ``hydra.compose``
    returns a ``MagicMock`` config carrying the two values assigned below.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    fake_cfg = mock.MagicMock()
    # MagicMock returns the same child for repeated attribute access, so
    # both assignments land on the same ``agents.talk2scholars`` node.
    talk2scholars_cfg = fake_cfg.agents.talk2scholars
    talk2scholars_cfg.s2_agent.temperature = 0
    talk2scholars_cfg.paper_download_agent.prompt = "Test prompt"

    with mock.patch("hydra.initialize"):
        with mock.patch("hydra.compose", return_value=fake_cfg) as compose_patch:
            yield compose_patch
|
20
|
+
|
21
|
+
|
22
|
+
@pytest.fixture
def mock_tools_fixture():
    """Replace the paper download agent's tools with mocks.

    Yields:
        A two-item list: the mock for ``download_arxiv_paper`` followed by
        the mock for ``query_results``, each with a canned return value.
    """
    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.paper_download."
            "download_arxiv_input.download_arxiv_paper",
            return_value={"pdf_data": {"dummy_key": "dummy_value"}},
        ) as download_mock,
        mock.patch(
            "aiagents4pharma.talk2scholars.tools.s2.query_results.query_results",
            return_value={"result": "Mocked Query Result"},
        ) as query_mock,
    ):
        yield [download_mock, query_mock]
|
41
|
+
|
42
|
+
@pytest.mark.usefixtures("mock_hydra_fixture")
def test_paper_download_agent_initialization():
    """Check that ``get_app`` builds an app and asks for a ReAct agent.

    ``create_react_agent`` is patched inside the agent module, so no real
    agent is constructed; the test only verifies that the factory was
    called and that ``get_app`` returned something.
    """
    fake_llm = mock.Mock(spec=BaseChatModel)  # stand-in for a chat model

    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent",
        return_value=mock.Mock(),
    ) as create_agent_patch:
        app = get_app("test_thread_paper_dl", fake_llm)

    assert app is not None, "The agent app should be successfully created."
    assert create_agent_patch.called
|
56
|
+
|
57
|
+
@pytest.mark.usefixtures("mock_tools_fixture")
def test_paper_download_agent_invocation():
    """Verify that invoking the app returns both messages and PDF data.

    The tool mocks are requested through ``usefixtures`` so that pytest
    actually activates them for this test. A fixture only takes effect when
    pytest injects it; merely naming the fixture function in the test body
    does not apply its patches.

    ``create_react_agent`` is patched so the agent's ``invoke`` returns a
    fixed payload containing ``messages`` and ``pdf_data``.
    """
    thread_id = "test_thread_paper_dl"
    mock_state = Talk2Scholars(
        messages=[HumanMessage(content="Download paper 1234.5678")]
    )
    llm_mock = mock.Mock(spec=BaseChatModel)

    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
    ) as mock_create_agent:
        mock_agent = mock.Mock()
        mock_create_agent.return_value = mock_agent
        mock_agent.invoke.return_value = {
            "messages": [AIMessage(content="Here is the paper")],
            "pdf_data": {"file_bytes": b"FAKE_PDF_CONTENTS"},
        }

        app = get_app(thread_id, llm_mock)
        result = app.invoke(
            mock_state,
            config={
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": "test_ns",
                    "checkpoint_id": "test_checkpoint",
                }
            },
        )

    assert "messages" in result
    assert "pdf_data" in result
|
91
|
+
|
92
|
+
|
93
|
+
def test_paper_download_agent_tools_assignment(request):  # Keep fixture name
    """Check that ``get_app`` constructs a ``ToolNode``.

    Both ``create_react_agent`` and ``ToolNode`` are patched inside the agent
    module. The patched ``ToolNode`` returns a mock whose ``tools`` attribute
    is the list yielded by ``mock_tools_fixture`` (two mocks).
    """
    tool_mocks = request.getfixturevalue("mock_tools_fixture")
    fake_llm = mock.Mock(spec=BaseChatModel)
    fake_tool_node = mock.Mock(tools=tool_mocks)

    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent",
            return_value=mock.Mock(),
        ),
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.paper_download_agent.ToolNode",
            return_value=fake_tool_node,
        ) as toolnode_patch,
    ):
        get_app("test_thread_paper_dl", fake_llm)

    assert toolnode_patch.called
    assert len(fake_tool_node.tools) == 2
|
116
|
+
|
117
|
+
|
118
|
+
def test_paper_download_agent_hydra_failure():
    """Check that an exception from ``hydra.initialize`` escapes ``get_app``."""
    fake_llm = mock.Mock(spec=BaseChatModel)
    hydra_error = Exception("Mock Hydra failure")

    # pytest.raises is the inner manager, so it captures the error before
    # the patch is undone.
    with mock.patch("hydra.initialize", side_effect=hydra_error), pytest.raises(
        Exception
    ) as exc_info:
        get_app("test_thread_paper_dl", fake_llm)

    assert "Mock Hydra failure" in str(exc_info.value)
|
127
|
+
|
128
|
+
|
129
|
+
def test_paper_download_agent_model_failure():
    """Check that an exception from ``create_react_agent`` escapes ``get_app``."""
    fake_llm = mock.Mock(spec=BaseChatModel)
    model_error = Exception("Mock model failure")

    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent",
        side_effect=model_error,
    ), pytest.raises(Exception) as exc_info:
        get_app("test_thread_paper_dl", fake_llm)

    failure_text = str(exc_info.value)
    assert (
        "Mock model failure" in failure_text
    ), "Model initialization failure should raise an exception."
|
@@ -0,0 +1,154 @@
|
|
1
|
+
"""
|
2
|
+
Unit tests for arXiv paper downloading functionality, including:
|
3
|
+
- AbstractPaperDownloader (base class)
|
4
|
+
- ArxivPaperDownloader (arXiv-specific implementation)
|
5
|
+
- download_arxiv_paper tool function.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from unittest.mock import patch, MagicMock
|
9
|
+
import pytest
|
10
|
+
import requests
|
11
|
+
from requests.exceptions import HTTPError
|
12
|
+
from langgraph.types import Command
|
13
|
+
from langchain_core.messages import ToolMessage
|
14
|
+
|
15
|
+
# Import the classes and function under test
|
16
|
+
from aiagents4pharma.talk2scholars.tools.paper_download.abstract_downloader import (
|
17
|
+
AbstractPaperDownloader,
|
18
|
+
)
|
19
|
+
from aiagents4pharma.talk2scholars.tools.paper_download.arxiv_downloader import (
|
20
|
+
ArxivPaperDownloader,
|
21
|
+
)
|
22
|
+
from aiagents4pharma.talk2scholars.tools.paper_download.download_arxiv_input import (
|
23
|
+
download_arxiv_paper,
|
24
|
+
)
|
25
|
+
|
26
|
+
@pytest.mark.parametrize("class_obj", [AbstractPaperDownloader])
def test_abstract_downloader_cannot_be_instantiated(class_obj):
    """Check that instantiating ``AbstractPaperDownloader`` raises ``TypeError``.

    The class is supplied through ``parametrize`` and called with no
    arguments inside ``pytest.raises``.
    """
    with pytest.raises(TypeError):
        _ = class_obj()
|
35
|
+
|
36
|
+
|
37
|
+
@pytest.fixture(name="arxiv_downloader_fixture")
def fixture_arxiv_downloader():
    """Provide an ``ArxivPaperDownloader`` instance to each requesting test.

    No mark is applied to this fixture: pytest ignores marks placed on
    fixture functions (and newer releases reject them), so a ``usefixtures``
    mark here would not set anything up.

    Returns:
        A newly constructed ``ArxivPaperDownloader``.
    """
    # NOTE(review): the constructor presumably reads the real Hydra config;
    # if a mocked config is wanted, request that fixture as a parameter of
    # this function instead of using a mark -- confirm against the class.
    return ArxivPaperDownloader()
|
44
|
+
|
45
|
+
|
46
|
+
def test_fetch_metadata_success(arxiv_downloader_fixture):
    """Check ``fetch_metadata`` on a successful HTTP response.

    ``requests.get`` is patched to return a mock whose ``text`` is a small
    XML string. The test asserts the exact URL and timeout passed to
    ``requests.get`` and that the XML is returned under the ``"xml"`` key.
    """
    xml_payload = "<xml>Mock ArXiv Metadata</xml>"
    fake_response = MagicMock(text=xml_payload, raise_for_status=MagicMock())

    with patch.object(requests, "get", return_value=fake_response) as get_patch:
        metadata = arxiv_downloader_fixture.fetch_metadata("1234.5678")

    get_patch.assert_called_once_with(
        "http://export.arxiv.org/api/query?search_query=id:1234.5678&start=0&max_results=1",
        timeout=10,
    )
    assert metadata["xml"] == "<xml>Mock ArXiv Metadata</xml>"
|
62
|
+
|
63
|
+
|
64
|
+
def test_fetch_metadata_http_error(arxiv_downloader_fixture):
    """Check that ``fetch_metadata`` lets an ``HTTPError`` propagate.

    The patched ``requests.get`` returns a response whose
    ``raise_for_status`` raises ``HTTPError``.
    """
    failing_response = MagicMock()
    failing_response.raise_for_status.side_effect = HTTPError("Mocked HTTP failure")

    with patch.object(requests, "get", return_value=failing_response), pytest.raises(
        HTTPError
    ):
        arxiv_downloader_fixture.fetch_metadata("invalid_id")
|
74
|
+
|
75
|
+
|
76
|
+
def test_download_pdf_success(arxiv_downloader_fixture):
    """Check ``download_pdf`` when the metadata contains a PDF link.

    ``fetch_metadata`` is patched to return an Atom feed with one
    ``<link title="pdf">`` entry, and ``requests.get`` is patched to return
    a response that streams a single fake chunk. The test asserts the
    returned ``pdf_object``, ``pdf_url`` and ``arxiv_id`` and the exact
    arguments passed to ``requests.get``.
    """
    feed_metadata = {
        "xml": """
        <feed xmlns="http://www.w3.org/2005/Atom">
            <entry>
                <link title="pdf" href="http://test.arxiv.org/pdf/1234.5678v1.pdf"/>
            </entry>
        </feed>
        """
    }

    def fake_iter_content(chunk_size):  # pylint: disable=unused-argument
        """Return the whole fake PDF as one chunk."""
        return [b"FAKE_PDF_CONTENT"]

    pdf_response = MagicMock()
    pdf_response.raise_for_status = MagicMock()
    pdf_response.iter_content = fake_iter_content

    with patch.object(
        arxiv_downloader_fixture, "fetch_metadata", return_value=feed_metadata
    ), patch.object(requests, "get", return_value=pdf_response) as get_patch:
        downloaded = arxiv_downloader_fixture.download_pdf("1234.5678")

    assert downloaded["pdf_object"] == b"FAKE_PDF_CONTENT"
    assert downloaded["pdf_url"] == "http://test.arxiv.org/pdf/1234.5678v1.pdf"
    assert downloaded["arxiv_id"] == "1234.5678"
    get_patch.assert_called_once_with(
        "http://test.arxiv.org/pdf/1234.5678v1.pdf",
        stream=True,
        timeout=10,
    )
|
106
|
+
|
107
|
+
|
108
|
+
def test_download_pdf_no_pdf_link(arxiv_downloader_fixture):
    """Check that ``download_pdf`` raises ``RuntimeError`` without a PDF link.

    ``fetch_metadata`` is patched to return an empty ``<feed>``; the error
    message is expected to match ``"Failed to download PDF"``.
    """
    empty_feed = {"xml": "<feed></feed>"}

    with patch.object(
        arxiv_downloader_fixture, "fetch_metadata", return_value=empty_feed
    ), pytest.raises(RuntimeError, match="Failed to download PDF"):
        arxiv_downloader_fixture.download_pdf("1234.5678")
|
117
|
+
|
118
|
+
|
119
|
+
def test_download_arxiv_paper_tool_success(arxiv_downloader_fixture):
    """Check the ``download_arxiv_paper`` tool on the success path.

    The ``ArxivPaperDownloader`` name inside the tool module is patched to
    return the fixture instance, whose ``fetch_metadata`` and
    ``download_pdf`` are stubbed. The tool is expected to return a
    ``Command`` whose update holds the stubbed PDF data and a single
    ``ToolMessage`` mentioning success and the arXiv id.
    """
    fake_pdf_data = {
        "pdf_object": b"FAKE_PDF_CONTENT",
        "pdf_url": "http://test.arxiv.org/mock.pdf",
        "arxiv_id": "9999.8888",
    }

    # The patchers are created up front and entered in the same order as
    # the nested ``with`` statements they replace.
    class_patch = patch(
        "aiagents4pharma.talk2scholars.tools.paper_download.download_arxiv_input."
        "ArxivPaperDownloader",
        return_value=arxiv_downloader_fixture,
    )
    metadata_patch = patch.object(
        arxiv_downloader_fixture,
        "fetch_metadata",
        return_value={"xml": "<mockxml></mockxml>"},
    )
    pdf_patch = patch.object(
        arxiv_downloader_fixture,
        "download_pdf",
        return_value=fake_pdf_data,
    )

    with class_patch, metadata_patch, pdf_patch:
        command_result = download_arxiv_paper.invoke(
            {"arxiv_id": "9999.8888", "tool_call_id": "test_tool_call"}
        )

    assert isinstance(command_result, Command)
    assert "pdf_data" in command_result.update
    assert command_result.update["pdf_data"] == fake_pdf_data

    tool_messages = command_result.update.get("messages", [])
    assert len(tool_messages) == 1
    only_message = tool_messages[0]
    assert isinstance(only_message, ToolMessage)
    assert "Successfully downloaded PDF" in only_message.content
    assert "9999.8888" in only_message.content
|
@@ -3,7 +3,6 @@ Unit tests for question_and_answer tool functionality.
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
from langchain.docstore.document import Document
|
6
|
-
|
7
6
|
from ..tools.pdf import question_and_answer
|
8
7
|
from ..tools.pdf.question_and_answer import (
|
9
8
|
extract_text_from_pdf_data,
|
@@ -11,6 +10,8 @@ from ..tools.pdf.question_and_answer import (
|
|
11
10
|
generate_answer,
|
12
11
|
)
|
13
12
|
|
13
|
+
# pylint: disable=redefined-outer-name,too-few-public-methods
|
14
|
+
|
14
15
|
|
15
16
|
def test_extract_text_from_pdf_data():
|
16
17
|
"""
|
@@ -46,14 +47,14 @@ DUMMY_PDF_BYTES = (
|
|
46
47
|
)
|
47
48
|
|
48
49
|
|
49
|
-
def
|
50
|
+
def fake_generate_answer2(question, pdf_url, _text_embedding_model):
|
50
51
|
"""
|
51
|
-
Fake
|
52
|
+
Fake generate_answer2 function to bypass external dependencies.
|
52
53
|
"""
|
53
54
|
return {
|
54
55
|
"answer": "Mock answer",
|
55
56
|
"question": question,
|
56
|
-
"
|
57
|
+
"pdf_url": pdf_url,
|
57
58
|
}
|
58
59
|
|
59
60
|
|
@@ -61,30 +62,31 @@ def test_question_and_answer_tool_success(monkeypatch):
|
|
61
62
|
"""
|
62
63
|
Test that question_and_answer_tool returns the expected result on success.
|
63
64
|
"""
|
64
|
-
|
65
|
-
|
66
|
-
)
|
67
|
-
# Create a valid state with pdf_data
|
68
|
-
# and include a dummy llm_model.
|
65
|
+
# Patch generate_answer2 because the tool calls that.
|
66
|
+
monkeypatch.setattr(question_and_answer, "generate_answer2", fake_generate_answer2)
|
67
|
+
dummy_text_embedding_model = object() # Provide a dummy text embedding model.
|
68
|
+
# Create a valid state with pdf_data and include dummy llm_model and text_embedding_model.
|
69
69
|
state = {
|
70
70
|
"pdf_data": {"pdf_object": DUMMY_PDF_BYTES, "pdf_url": "http://dummy.url"},
|
71
71
|
"llm_model": object(), # Provide a dummy LLM model instance.
|
72
|
+
"text_embedding_model": dummy_text_embedding_model,
|
72
73
|
}
|
73
74
|
question = "What is in the PDF?"
|
74
|
-
# Call the underlying function directly via .func to bypass the StructuredTool wrapper.
|
75
75
|
result = question_and_answer_tool.func(
|
76
76
|
question=question, tool_call_id="test_call_id", state=state
|
77
77
|
)
|
78
78
|
assert result["answer"] == "Mock answer"
|
79
79
|
assert result["question"] == question
|
80
|
-
assert result["
|
80
|
+
assert result["pdf_url"] == "http://dummy.url"
|
81
81
|
|
82
82
|
|
83
83
|
def test_question_and_answer_tool_no_pdf_data():
|
84
84
|
"""
|
85
85
|
Test that an error is returned if the state lacks the 'pdf_data' key.
|
86
86
|
"""
|
87
|
-
state = {
|
87
|
+
state = {
|
88
|
+
"text_embedding_model": object(), # Added to avoid KeyError.
|
89
|
+
}
|
88
90
|
question = "Any question?"
|
89
91
|
result = question_and_answer_tool.func(
|
90
92
|
question=question, tool_call_id="test_call_id", state=state
|
@@ -97,7 +99,11 @@ def test_question_and_answer_tool_no_pdf_object():
|
|
97
99
|
"""
|
98
100
|
Test that an error is returned if the pdf_object is missing within pdf_data.
|
99
101
|
"""
|
100
|
-
state = {
|
102
|
+
state = {
|
103
|
+
"pdf_data": {"pdf_object": None},
|
104
|
+
"text_embedding_model": object(), # Added to avoid KeyError.
|
105
|
+
"llm_model": object(), # Dummy LLM model.
|
106
|
+
}
|
101
107
|
question = "Any question?"
|
102
108
|
result = question_and_answer_tool.func(
|
103
109
|
question=question, tool_call_id="test_call_id", state=state
|
@@ -114,8 +120,9 @@ def test_question_and_answer_tool_no_llm_model():
|
|
114
120
|
Test that an error is returned if the LLM model is missing in the state.
|
115
121
|
"""
|
116
122
|
state = {
|
117
|
-
"pdf_data": {"pdf_object": DUMMY_PDF_BYTES, "pdf_url": "http://dummy.url"}
|
118
|
-
#
|
123
|
+
"pdf_data": {"pdf_object": DUMMY_PDF_BYTES, "pdf_url": "http://dummy.url"},
|
124
|
+
"text_embedding_model": object(), # Added to avoid KeyError.
|
125
|
+
# llm_model is intentionally omitted.
|
119
126
|
}
|
120
127
|
question = "What is in the PDF?"
|
121
128
|
result = question_and_answer_tool.func(
|
@@ -124,6 +131,57 @@ def test_question_and_answer_tool_no_llm_model():
|
|
124
131
|
assert result == {"error": "No LLM model found in state."}
|
125
132
|
|
126
133
|
|
134
|
+
def test_generate_answer2_actual(monkeypatch):
    """Run the real ``generate_answer2`` against stubbed dependencies.

    ``PyPDFLoader`` and ``InMemoryVectorStore.from_documents`` are replaced
    on the ``question_and_answer`` module so nothing is loaded from the
    network. The result is expected to equal the page content produced by
    the fake similarity search.
    """

    class FakePyPDFLoader:
        """Stand-in for PyPDFLoader that only records its arguments."""

        def __init__(self, file_path, headers=None):
            """Store the constructor arguments."""
            self.file_path = file_path
            self.headers = headers

        def lazy_load(self):
            """Return a list holding one fake Document."""
            return [Document(page_content="Answer for Test question?")]

    class FakeVectorStore:
        """Stand-in vector store with a canned similarity search."""

        def similarity_search(self, query):
            """Return one Document whose content echoes the query."""
            return [Document(page_content=f"Answer for {query}")]

    def fake_from_documents(docs, emb):  # pylint: disable=unused-argument
        """Ignore the inputs and hand back the fake vector store."""
        return FakeVectorStore()

    monkeypatch.setattr(question_and_answer, "PyPDFLoader", FakePyPDFLoader)
    monkeypatch.setattr(
        question_and_answer.InMemoryVectorStore,
        "from_documents",
        fake_from_documents,
    )

    # A bare object is enough for the embedding model because the vector
    # store that would consume it is faked.
    answer = question_and_answer.generate_answer2(
        "Test question?", "http://dummy.pdf", object()
    )

    assert answer == "Answer for Test question?"
|
183
|
+
|
184
|
+
|
127
185
|
def test_generate_answer(monkeypatch):
|
128
186
|
"""
|
129
187
|
Test generate_answer function with controlled monkeypatched dependencies.
|
@@ -141,12 +199,15 @@ def test_generate_answer(monkeypatch):
|
|
141
199
|
"""
|
142
200
|
Fake Annoy.from_documents function that returns a fake vector store.
|
143
201
|
"""
|
202
|
+
|
144
203
|
# pylint: disable=too-few-public-methods, unused-argument
|
145
204
|
class FakeVectorStore:
|
146
205
|
"""Fake vector store for similarity search."""
|
206
|
+
|
147
207
|
def similarity_search(self, _question, k):
|
148
208
|
"""Return a list with a single dummy Document."""
|
149
209
|
return [Document(page_content="dummy content")]
|
210
|
+
|
150
211
|
return FakeVectorStore()
|
151
212
|
|
152
213
|
monkeypatch.setattr(
|
@@ -157,9 +218,11 @@ def test_generate_answer(monkeypatch):
|
|
157
218
|
"""
|
158
219
|
Fake load_qa_chain function that returns a fake QA chain.
|
159
220
|
"""
|
221
|
+
|
160
222
|
# pylint: disable=too-few-public-methods, unused-argument
|
161
223
|
class FakeChain:
|
162
224
|
"""Fake QA chain for testing generate_answer."""
|
225
|
+
|
163
226
|
def invoke(self, **kwargs):
|
164
227
|
"""
|
165
228
|
Fake invoke method that returns a mock answer.
|
@@ -169,6 +232,7 @@ def test_generate_answer(monkeypatch):
|
|
169
232
|
"answer": "real mock answer",
|
170
233
|
"question": input_data.get("question"),
|
171
234
|
}
|
235
|
+
|
172
236
|
return FakeChain()
|
173
237
|
|
174
238
|
monkeypatch.setattr(question_and_answer, "load_qa_chain", fake_load_qa_chain)
|
@@ -23,8 +23,7 @@ def mock_router():
|
|
23
23
|
|
24
24
|
def mock_supervisor_node(state):
|
25
25
|
query = state["messages"][-1].content.lower()
|
26
|
-
|
27
|
-
# Expanded keyword matching for S2 Agent
|
26
|
+
# Define keywords for each sub-agent.
|
28
27
|
s2_keywords = [
|
29
28
|
"paper",
|
30
29
|
"research",
|
@@ -34,13 +33,19 @@ def mock_router():
|
|
34
33
|
"references",
|
35
34
|
]
|
36
35
|
zotero_keywords = ["zotero", "library", "saved papers", "academic library"]
|
36
|
+
pdf_keywords = ["pdf", "document", "read pdf"]
|
37
|
+
paper_download_keywords = ["download", "arxiv", "fetch paper", "paper download"]
|
37
38
|
|
39
|
+
# Priority ordering: Zotero, then paper download, then PDF, then S2.
|
38
40
|
if any(keyword in query for keyword in zotero_keywords):
|
39
41
|
return Command(goto="zotero_agent")
|
42
|
+
if any(keyword in query for keyword in paper_download_keywords):
|
43
|
+
return Command(goto="paper_download_agent")
|
44
|
+
if any(keyword in query for keyword in pdf_keywords):
|
45
|
+
return Command(goto="pdf_agent")
|
40
46
|
if any(keyword in query for keyword in s2_keywords):
|
41
47
|
return Command(goto="s2_agent")
|
42
|
-
|
43
|
-
# If no match, default to ending the conversation
|
48
|
+
# Default to end if no keyword matches.
|
44
49
|
return Command(goto=END)
|
45
50
|
|
46
51
|
return mock_supervisor_node
|
@@ -55,10 +60,9 @@ def mock_router():
|
|
55
60
|
("Fetch my academic library.", "zotero_agent"),
|
56
61
|
("Retrieve citations.", "s2_agent"),
|
57
62
|
("Can you get journal articles?", "s2_agent"),
|
58
|
-
(
|
59
|
-
|
60
|
-
|
61
|
-
), # NEW: Should trigger the `END` case
|
63
|
+
("I want to read the PDF document.", "pdf_agent"),
|
64
|
+
("Download the paper from arxiv.", "paper_download_agent"),
|
65
|
+
("Completely unrelated query.", "__end__"),
|
62
66
|
],
|
63
67
|
)
|
64
68
|
def test_routing_logic(mock_state, mock_router, user_query, expected_agent):
|
@@ -94,7 +94,7 @@ def dummy_requests_post_success(url, headers, params, data, timeout):
|
|
94
94
|
{
|
95
95
|
"paperId": "paperA",
|
96
96
|
"title": "Multi Rec Paper A",
|
97
|
-
"authors": ["Author X"],
|
97
|
+
"authors": [{"name": "Author X", "authorId": "AX"}],
|
98
98
|
"year": 2019,
|
99
99
|
"citationCount": 12,
|
100
100
|
"url": "http://paperA",
|
@@ -103,7 +103,7 @@ def dummy_requests_post_success(url, headers, params, data, timeout):
|
|
103
103
|
{
|
104
104
|
"paperId": "paperB",
|
105
105
|
"title": "Multi Rec Paper B",
|
106
|
-
"authors": ["Author Y"],
|
106
|
+
"authors": [{"name": "Author Y", "authorId": "AY"}],
|
107
107
|
"year": 2020,
|
108
108
|
"citationCount": 18,
|
109
109
|
"url": "http://paperB",
|
@@ -112,7 +112,7 @@ def dummy_requests_post_success(url, headers, params, data, timeout):
|
|
112
112
|
{
|
113
113
|
"paperId": "paperC",
|
114
114
|
"title": "Multi Rec Paper C",
|
115
|
-
"authors": None, # This
|
115
|
+
"authors": None, # This paper should be filtered out.
|
116
116
|
"year": 2021,
|
117
117
|
"citationCount": 25,
|
118
118
|
"url": "http://paperC",
|
@@ -277,6 +277,29 @@ def test_multi_paper_rec_requests_exception(monkeypatch):
|
|
277
277
|
}
|
278
278
|
with pytest.raises(
|
279
279
|
RuntimeError,
|
280
|
-
match="Failed to connect to Semantic Scholar API
|
280
|
+
match="Failed to connect to Semantic Scholar API after 10 attempts."
|
281
|
+
"Please retry the same query.",
|
282
|
+
):
|
283
|
+
get_multi_paper_recommendations.run(input_data)
|
284
|
+
|
285
|
+
|
286
|
+
def test_multi_paper_rec_no_response(monkeypatch):
    """Check the error raised when the tool never obtains a response.

    ``range`` is overridden in the globals of the tool's underlying function
    so that it yields nothing. The test expects a ``RuntimeError`` whose
    message matches the "Failed to obtain a response" text.
    """
    tool_globals = get_multi_paper_recommendations.func.__globals__
    # An empty ``range`` means any loop over it in the tool never iterates.
    monkeypatch.setitem(tool_globals, "range", lambda x: iter([]))

    payload = {
        "paper_ids": ["p1", "p2"],
        "tool_call_id": "test_tool_call_id",
    }
    expected_error = "Failed to obtain a response from the Semantic Scholar API."

    with pytest.raises(RuntimeError, match=expected_error):
        get_multi_paper_recommendations.run(payload)
|
@@ -85,7 +85,7 @@ def dummy_requests_get_success(url, params, timeout):
|
|
85
85
|
{
|
86
86
|
"paperId": "1",
|
87
87
|
"title": "Paper 1",
|
88
|
-
"authors": ["Author A"],
|
88
|
+
"authors": [{"name": "Author A", "authorId": "A1"}],
|
89
89
|
"year": 2020,
|
90
90
|
"citationCount": 10,
|
91
91
|
"url": "http://paper1",
|
@@ -94,7 +94,7 @@ def dummy_requests_get_success(url, params, timeout):
|
|
94
94
|
{
|
95
95
|
"paperId": "2",
|
96
96
|
"title": "Paper 2",
|
97
|
-
"authors": ["Author B"],
|
97
|
+
"authors": [{"name": "Author B", "authorId": "B1"}],
|
98
98
|
"year": 2021,
|
99
99
|
"citationCount": 20,
|
100
100
|
"url": "http://paper2",
|
@@ -256,7 +256,8 @@ def test_search_tool_requests_exception(monkeypatch):
|
|
256
256
|
tool_call_id = "test_tool_call_id"
|
257
257
|
with pytest.raises(
|
258
258
|
RuntimeError,
|
259
|
-
match="Failed to connect to Semantic Scholar API
|
259
|
+
match="Failed to connect to Semantic Scholar API after 10 attempts."
|
260
|
+
"Please retry the same query.",
|
260
261
|
):
|
261
262
|
search_tool.run(
|
262
263
|
{
|
@@ -264,3 +265,18 @@ def test_search_tool_requests_exception(monkeypatch):
|
|
264
265
|
"tool_call_id": tool_call_id,
|
265
266
|
}
|
266
267
|
)
|
268
|
+
|
269
|
+
|
270
|
+
def test_search_tool_no_response(monkeypatch):
    """Check the error raised when ``search_tool`` never obtains a response.

    ``range`` is overridden in the globals of the tool's underlying function
    so that it yields nothing. The test expects a ``RuntimeError`` whose
    message matches the "Failed to obtain a response" text.
    """
    tool_globals = search_tool.func.__globals__
    # An empty ``range`` means any loop over it in the tool never iterates.
    monkeypatch.setitem(tool_globals, "range", lambda x: iter([]))

    payload = {"query": "test", "tool_call_id": "test_tool_call_id"}
    expected_error = "Failed to obtain a response from the Semantic Scholar API."

    with pytest.raises(RuntimeError, match=expected_error):
        search_tool.run(payload)
|